Skip to content

Commit 6da8e1f

Browse files
authored
Move listener to postgres (#738)
Move the listener built in cog to use postgres instead of mongo as a database
1 parent f1947b1 commit 6da8e1f

File tree

2 files changed

+99
-52
lines changed

2 files changed

+99
-52
lines changed

techsupport_bot/base/databases.py

+10
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ class Warning(bot.db.Model):
136136
reason = bot.db.Column(bot.db.String)
137137
time = bot.db.Column(bot.db.DateTime, default=datetime.datetime.utcnow)
138138

139+
class Listener(bot.db.Model):
140+
"""The postgres table for listeners
141+
Currently used in listen.py"""
142+
143+
__tablename__ = "listeners"
144+
pk = bot.db.Column(bot.db.Integer, primary_key=True)
145+
src_id = bot.db.Column(bot.db.String)
146+
dst_id = bot.db.Column(bot.db.String)
147+
139148
class Rule(bot.db.Model):
140149
"""The postgres table for rules
141150
Currently used in rules.py"""
@@ -154,4 +163,5 @@ class Rule(bot.db.Model):
154163
bot.models.IRCChannelMapping = IRCChannelMapping
155164
bot.models.UserNote = UserNote
156165
bot.models.Warning = Warning
166+
bot.models.Listener = Listener
157167
bot.models.Rule = Rule

techsupport_bot/cogs/listen.py

+89-52
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
"""Module for channel listening.
22
"""
3+
from __future__ import annotations
34

45
import datetime
6+
from typing import TYPE_CHECKING
57

68
import discord
79
import expiringdict
810
import ui
911
from base import auxiliary, cogs
1012
from discord.ext import commands
1113

14+
if TYPE_CHECKING:
15+
import bot
16+
1217

1318
class ListenChannel(commands.Converter):
1419
"""Converter for grabbing a channel via the API.
@@ -72,16 +77,13 @@ class Listener(cogs.BaseCog):
7277
ADMIN_ONLY = True
7378
MAX_DESTINATIONS = 10
7479
CACHE_TIME = 60
75-
COLLECTION_NAME = "listener"
7680

7781
async def preconfig(self):
7882
"""Preconfigures the listener cog."""
7983
self.destination_cache = expiringdict.ExpiringDict(
8084
max_len=1000,
8185
max_age_seconds=1200,
8286
)
83-
if not self.COLLECTION_NAME in await self.bot.mongo.list_collection_names():
84-
await self.bot.mongo.create_collection(self.COLLECTION_NAME)
8587

8688
async def get_destinations(self, src):
8789
"""Gets channel object destinations for a given source channel.
@@ -104,17 +106,18 @@ async def build_destinations_from_src(self, src):
104106
src (discord.TextChannel): the source channel to build for
105107
"""
106108
destination_data = await self.get_destination_data(src)
107-
destination_ids = (
108-
destination_data.get("destinations", []) if destination_data else []
109-
)
110-
destinations = await self.build_destinations(destination_ids)
109+
if not destination_data:
110+
return None
111+
destinations = await self.build_destinations(destination_data)
111112
return destinations
112113

113-
async def build_destinations(self, destination_ids):
114+
async def build_destinations(
115+
self, destination_ids: list[int]
116+
) -> list[discord.abc.Messageable]:
114117
"""Converts destination ID's to their actual channels objects.
115118
116119
parameters:
117-
destination_ids ([int]): the destination ID's to reference
120+
destination_ids (list[int]): the destination ID's to reference
118121
"""
119122
destinations = set()
120123
for did in destination_ids:
@@ -132,34 +135,77 @@ async def build_destinations(self, destination_ids):
132135

133136
return destinations
134137

135-
async def get_destination_data(self, src):
138+
async def get_destination_data(self, src: discord.TextChannel) -> list[str]:
136139
"""Retrieves raw destination data given a source channel.
137140
138141
parameters:
139142
src (discord.TextChannel): the source channel to build for
140143
"""
141-
destination_data = await self.bot.mongo[self.COLLECTION_NAME].find_one(
142-
{"source_id": {"$eq": str(src.id)}}
144+
destination_data = await self.bot.models.Listener.query.where(
145+
self.bot.models.Listener.src_id == str(src.id)
146+
).gino.all()
147+
if not destination_data:
148+
return None
149+
150+
return [listener.dst_id for listener in destination_data]
151+
152+
def build_list_of_sources(
153+
self, listeners: list[bot.db.model.Listener]
154+
) -> list[str]:
155+
"""Builds a list of unique sources from the raw database output
156+
157+
Args:
158+
listeners (list[bot.db.model.Listener]): The entire database dumped into a list
159+
160+
Returns:
161+
list[str]: The list of unique src channel strings
162+
"""
163+
src_id_list = [listener.src_id for listener in listeners]
164+
final_list = list(set(src_id_list))
165+
return final_list
166+
167+
async def get_specific_listener(
168+
self, src: discord.TextChannel, dst: discord.TextChannel
169+
) -> bot.db.models.Listener:
170+
"""Gets a database object of the given listener pair
171+
172+
Args:
173+
src (discord.TextChannel): The source channel
174+
dst (discord.TextChannel): The destination channel
175+
176+
Returns:
177+
bot.db.models.Listener: The db object, if the listener exists
178+
"""
179+
listener = (
180+
await self.bot.models.Listener.query.where(
181+
self.bot.models.Listener.src_id == str(src.id)
182+
)
183+
.where(self.bot.models.Listener.dst_id == str(dst.id))
184+
.gino.first()
143185
)
144-
return destination_data
186+
return listener
145187

146188
async def get_all_sources(self):
147189
"""Gets all source data.
148190
149191
This is kind of expensive, so use lightly.
150192
"""
151193
source_objects = []
152-
cursor = self.bot.mongo[self.COLLECTION_NAME].find({})
153-
for doc in await cursor.to_list(length=50):
154-
src_ch = self.bot.get_channel(int(doc.get("source_id"), 0))
194+
all_listens = await self.bot.models.Listener.query.gino.all()
195+
source_list = self.build_list_of_sources(all_listens)
196+
for src in source_list:
197+
src_ch = self.bot.get_channel(int(src))
155198
if not src_ch:
156199
continue
157200

158-
destination_ids = doc.get("destinations")
159-
if not destination_ids:
201+
destination_ids = await self.bot.models.Listener.query.where(
202+
self.bot.models.Listener.src_id == src
203+
).gino.all()
204+
dst_id_list = [listener.dst_id for listener in destination_ids]
205+
if not dst_id_list:
160206
continue
161207

162-
destinations = await self.build_destinations(destination_ids)
208+
destinations = await self.build_destinations(dst_id_list)
163209
if not destinations:
164210
continue
165211

@@ -169,18 +215,20 @@ async def get_all_sources(self):
169215

170216
return source_objects
171217

172-
async def update_destinations(self, src, destination_ids):
218+
async def update_destinations(
219+
self, src: discord.TextChannel, dst: discord.TextChannel
220+
) -> None:
173221
"""Updates destinations in Mongo given a src.
174222
175223
parameters:
176224
src (discord.TextChannel): the source channel to build for
177-
destination_ids ([int]): the destination ID's to reference
225+
dst (discord.TextChannel): the destination channel to build for
178226
"""
179-
as_str = str(src.id)
180-
new_data = {"source_id": as_str, "destinations": list(set(destination_ids))}
181-
await self.bot.mongo[self.COLLECTION_NAME].replace_one(
182-
{"source_id": as_str}, new_data, upsert=True
227+
new_listener = self.bot.models.Listener(
228+
src_id=str(src.id),
229+
dst_id=str(dst.id),
183230
)
231+
await new_listener.create()
184232
try:
185233
del self.destination_cache[src.id]
186234
except KeyError:
@@ -251,13 +299,15 @@ def get_help_embed(self, command_prefix):
251299
@listen.command(
252300
description="Starts a listening job", usage="[src-channel] [dst-channel]"
253301
)
254-
async def start(self, ctx, src: ListenChannel, dst: ListenChannel):
302+
async def start(
303+
self, ctx: commands.Context, src: ListenChannel, dst: ListenChannel
304+
):
255305
"""Executes a start-listening command.
256306
257307
This is a command and should be accessed via Discord.
258308
259309
parameters:
260-
ctx (discord.ext.Context): the context object for the message
310+
ctx (commands.Context): the context object for the message
261311
src (ListenChannel): the source channel ID
262312
dst (ListenChannel): the destination channel ID
263313
"""
@@ -268,26 +318,14 @@ async def start(self, ctx, src: ListenChannel, dst: ListenChannel):
268318
)
269319
return
270320

271-
destination_data = await self.get_destination_data(src)
272-
destinations = (
273-
destination_data.get("destinations", []) if destination_data else []
274-
)
275-
276-
if str(dst.id) in destinations:
321+
listener_object = await self.get_specific_listener(src, dst)
322+
if listener_object:
277323
await auxiliary.send_deny_embed(
278324
message="That source and destination already exist", channel=ctx.channel
279325
)
280326
return
281327

282-
if len(destinations) > self.MAX_DESTINATIONS:
283-
await auxiliary.send_deny_embed(
284-
message="There are too many destinations for that source",
285-
channel=ctx.channel,
286-
)
287-
return
288-
289-
destinations.append(str(dst.id))
290-
await self.update_destinations(src, destinations)
328+
await self.update_destinations(src, dst)
291329

292330
await auxiliary.send_confirm_embed(
293331
message="Listening registered!", channel=ctx.channel
@@ -313,19 +351,14 @@ async def stop(self, ctx, src: ListenChannel, dst: ListenChannel):
313351
)
314352
return
315353

316-
destination_data = await self.get_destination_data(src)
317-
destinations = (
318-
destination_data.get("destinations", []) if destination_data else []
319-
)
320-
if str(dst.id) not in destinations:
354+
listener_object = await self.get_specific_listener(src, dst)
355+
if not listener_object:
321356
await auxiliary.send_deny_embed(
322357
message="That destination is not registered with that source",
323358
channel=ctx.channel,
324359
)
325360
return
326-
327-
destinations.remove(str(dst.id))
328-
await self.update_destinations(src, destinations)
361+
await listener_object.delete()
329362

330363
await auxiliary.send_confirm_embed(
331364
message="Listening deregistered!", channel=ctx.channel
@@ -342,7 +375,9 @@ async def clear(self, ctx):
342375
parameters:
343376
ctx (discord.ext.Context): the context object for the message
344377
"""
345-
await self.bot.mongo[self.COLLECTION_NAME].delete_many({})
378+
all_listens = await self.bot.models.Listener.query.gino.all()
379+
for listener in all_listens:
380+
await listener.delete()
346381
self.destination_cache.clear()
347382

348383
await auxiliary.send_confirm_embed(
@@ -389,7 +424,7 @@ async def jobs(self, ctx):
389424
await ctx.send(embed=embed)
390425

391426
@commands.Cog.listener()
392-
async def on_message(self, message):
427+
async def on_message(self, message: discord.Message):
393428
"""Listens to message events.
394429
395430
parameters:
@@ -400,6 +435,8 @@ async def on_message(self, message):
400435
if isinstance(message.channel, discord.DMChannel):
401436
return
402437
destinations = await self.get_destinations(message.channel)
438+
if not destinations:
439+
return
403440
for dst in destinations:
404441
embed = MessageEmbed(message=message)
405442
await dst.send(embed=embed)

0 commit comments

Comments
 (0)