1
1
"""Module for channel listening.
2
2
"""
3
+ from __future__ import annotations
3
4
4
5
import datetime
6
+ from typing import TYPE_CHECKING
5
7
6
8
import discord
7
9
import expiringdict
8
10
import ui
9
11
from base import auxiliary , cogs
10
12
from discord .ext import commands
11
13
14
+ if TYPE_CHECKING :
15
+ import bot
16
+
12
17
13
18
class ListenChannel (commands .Converter ):
14
19
"""Converter for grabbing a channel via the API.
@@ -72,16 +77,13 @@ class Listener(cogs.BaseCog):
72
77
ADMIN_ONLY = True
73
78
MAX_DESTINATIONS = 10
74
79
CACHE_TIME = 60
75
- COLLECTION_NAME = "listener"
76
80
77
81
async def preconfig (self ):
78
82
"""Preconfigures the listener cog."""
79
83
self .destination_cache = expiringdict .ExpiringDict (
80
84
max_len = 1000 ,
81
85
max_age_seconds = 1200 ,
82
86
)
83
- if not self .COLLECTION_NAME in await self .bot .mongo .list_collection_names ():
84
- await self .bot .mongo .create_collection (self .COLLECTION_NAME )
85
87
86
88
async def get_destinations (self , src ):
87
89
"""Gets channel object destinations for a given source channel.
@@ -104,17 +106,18 @@ async def build_destinations_from_src(self, src):
104
106
src (discord.TextChannel): the source channel to build for
105
107
"""
106
108
destination_data = await self .get_destination_data (src )
107
- destination_ids = (
108
- destination_data .get ("destinations" , []) if destination_data else []
109
- )
110
- destinations = await self .build_destinations (destination_ids )
109
+ if not destination_data :
110
+ return None
111
+ destinations = await self .build_destinations (destination_data )
111
112
return destinations
112
113
113
- async def build_destinations (self , destination_ids ):
114
+ async def build_destinations (
115
+ self , destination_ids : list [int ]
116
+ ) -> list [discord .abc .Messageable ]:
114
117
"""Converts destination ID's to their actual channels objects.
115
118
116
119
parameters:
117
- destination_ids ([int]): the destination ID's to reference
120
+ destination_ids (list [int]): the destination ID's to reference
118
121
"""
119
122
destinations = set ()
120
123
for did in destination_ids :
@@ -132,34 +135,77 @@ async def build_destinations(self, destination_ids):
132
135
133
136
return destinations
134
137
135
- async def get_destination_data (self , src ) :
138
+ async def get_destination_data (self , src : discord . TextChannel ) -> list [ str ] :
136
139
"""Retrieves raw destination data given a source channel.
137
140
138
141
parameters:
139
142
src (discord.TextChannel): the source channel to build for
140
143
"""
141
- destination_data = await self .bot .mongo [self .COLLECTION_NAME ].find_one (
142
- {"source_id" : {"$eq" : str (src .id )}}
144
+ destination_data = await self .bot .models .Listener .query .where (
145
+ self .bot .models .Listener .src_id == str (src .id )
146
+ ).gino .all ()
147
+ if not destination_data :
148
+ return None
149
+
150
+ return [listener .dst_id for listener in destination_data ]
151
+
152
+ def build_list_of_sources (
153
+ self , listeners : list [bot .db .model .Listener ]
154
+ ) -> list [str ]:
155
+ """Builds a list of unique sources from the raw database output
156
+
157
+ Args:
158
+ listeners (list[bot.db.model.Listener]): The entire database dumped into a list
159
+
160
+ Returns:
161
+ list[str]: The list of unique src channel strings
162
+ """
163
+ src_id_list = [listener .src_id for listener in listeners ]
164
+ final_list = list (set (src_id_list ))
165
+ return final_list
166
+
167
+ async def get_specific_listener (
168
+ self , src : discord .TextChannel , dst : discord .TextChannel
169
+ ) -> bot .db .models .Listener :
170
+ """Gets a database object of the given listener pair
171
+
172
+ Args:
173
+ src (discord.TextChannel): The source channel
174
+ dst (discord.TextChannel): The destination channel
175
+
176
+ Returns:
177
+ bot.db.models.Listener: The db object, if the listener exists
178
+ """
179
+ listener = (
180
+ await self .bot .models .Listener .query .where (
181
+ self .bot .models .Listener .src_id == str (src .id )
182
+ )
183
+ .where (self .bot .models .Listener .dst_id == str (dst .id ))
184
+ .gino .first ()
143
185
)
144
- return destination_data
186
+ return listener
145
187
146
188
async def get_all_sources (self ):
147
189
"""Gets all source data.
148
190
149
191
This is kind of expensive, so use lightly.
150
192
"""
151
193
source_objects = []
152
- cursor = self .bot .mongo [self .COLLECTION_NAME ].find ({})
153
- for doc in await cursor .to_list (length = 50 ):
154
- src_ch = self .bot .get_channel (int (doc .get ("source_id" ), 0 ))
194
+ all_listens = await self .bot .models .Listener .query .gino .all ()
195
+ source_list = self .build_list_of_sources (all_listens )
196
+ for src in source_list :
197
+ src_ch = self .bot .get_channel (int (src ))
155
198
if not src_ch :
156
199
continue
157
200
158
- destination_ids = doc .get ("destinations" )
159
- if not destination_ids :
201
+ destination_ids = await self .bot .models .Listener .query .where (
202
+ self .bot .models .Listener .src_id == src
203
+ ).gino .all ()
204
+ dst_id_list = [listener .dst_id for listener in destination_ids ]
205
+ if not dst_id_list :
160
206
continue
161
207
162
- destinations = await self .build_destinations (destination_ids )
208
+ destinations = await self .build_destinations (dst_id_list )
163
209
if not destinations :
164
210
continue
165
211
@@ -169,18 +215,20 @@ async def get_all_sources(self):
169
215
170
216
return source_objects
171
217
172
- async def update_destinations (self , src , destination_ids ):
218
+ async def update_destinations (
219
+ self , src : discord .TextChannel , dst : discord .TextChannel
220
+ ) -> None :
173
221
"""Updates destinations in Mongo given a src.
174
222
175
223
parameters:
176
224
src (discord.TextChannel): the source channel to build for
177
- destination_ids ([int] ): the destination ID's to reference
225
+ dst (discord.TextChannel ): the destination channel to build for
178
226
"""
179
- as_str = str (src .id )
180
- new_data = {"source_id" : as_str , "destinations" : list (set (destination_ids ))}
181
- await self .bot .mongo [self .COLLECTION_NAME ].replace_one (
182
- {"source_id" : as_str }, new_data , upsert = True
227
+ new_listener = self .bot .models .Listener (
228
+ src_id = str (src .id ),
229
+ dst_id = str (dst .id ),
183
230
)
231
+ await new_listener .create ()
184
232
try :
185
233
del self .destination_cache [src .id ]
186
234
except KeyError :
@@ -251,13 +299,15 @@ def get_help_embed(self, command_prefix):
251
299
@listen .command (
252
300
description = "Starts a listening job" , usage = "[src-channel] [dst-channel]"
253
301
)
254
- async def start (self , ctx , src : ListenChannel , dst : ListenChannel ):
302
+ async def start (
303
+ self , ctx : commands .Context , src : ListenChannel , dst : ListenChannel
304
+ ):
255
305
"""Executes a start-listening command.
256
306
257
307
This is a command and should be accessed via Discord.
258
308
259
309
parameters:
260
- ctx (discord.ext .Context): the context object for the message
310
+ ctx (commands .Context): the context object for the message
261
311
src (ListenChannel): the source channel ID
262
312
dst (ListenChannel): the destination channel ID
263
313
"""
@@ -268,26 +318,14 @@ async def start(self, ctx, src: ListenChannel, dst: ListenChannel):
268
318
)
269
319
return
270
320
271
- destination_data = await self .get_destination_data (src )
272
- destinations = (
273
- destination_data .get ("destinations" , []) if destination_data else []
274
- )
275
-
276
- if str (dst .id ) in destinations :
321
+ listener_object = await self .get_specific_listener (src , dst )
322
+ if listener_object :
277
323
await auxiliary .send_deny_embed (
278
324
message = "That source and destination already exist" , channel = ctx .channel
279
325
)
280
326
return
281
327
282
- if len (destinations ) > self .MAX_DESTINATIONS :
283
- await auxiliary .send_deny_embed (
284
- message = "There are too many destinations for that source" ,
285
- channel = ctx .channel ,
286
- )
287
- return
288
-
289
- destinations .append (str (dst .id ))
290
- await self .update_destinations (src , destinations )
328
+ await self .update_destinations (src , dst )
291
329
292
330
await auxiliary .send_confirm_embed (
293
331
message = "Listening registered!" , channel = ctx .channel
@@ -313,19 +351,14 @@ async def stop(self, ctx, src: ListenChannel, dst: ListenChannel):
313
351
)
314
352
return
315
353
316
- destination_data = await self .get_destination_data (src )
317
- destinations = (
318
- destination_data .get ("destinations" , []) if destination_data else []
319
- )
320
- if str (dst .id ) not in destinations :
354
+ listener_object = await self .get_specific_listener (src , dst )
355
+ if not listener_object :
321
356
await auxiliary .send_deny_embed (
322
357
message = "That destination is not registered with that source" ,
323
358
channel = ctx .channel ,
324
359
)
325
360
return
326
-
327
- destinations .remove (str (dst .id ))
328
- await self .update_destinations (src , destinations )
361
+ await listener_object .delete ()
329
362
330
363
await auxiliary .send_confirm_embed (
331
364
message = "Listening deregistered!" , channel = ctx .channel
@@ -342,7 +375,9 @@ async def clear(self, ctx):
342
375
parameters:
343
376
ctx (discord.ext.Context): the context object for the message
344
377
"""
345
- await self .bot .mongo [self .COLLECTION_NAME ].delete_many ({})
378
+ all_listens = await self .bot .models .Listener .query .gino .all ()
379
+ for listener in all_listens :
380
+ await listener .delete ()
346
381
self .destination_cache .clear ()
347
382
348
383
await auxiliary .send_confirm_embed (
@@ -389,7 +424,7 @@ async def jobs(self, ctx):
389
424
await ctx .send (embed = embed )
390
425
391
426
@commands .Cog .listener ()
392
- async def on_message (self , message ):
427
+ async def on_message (self , message : discord . Message ):
393
428
"""Listens to message events.
394
429
395
430
parameters:
@@ -400,6 +435,8 @@ async def on_message(self, message):
400
435
if isinstance (message .channel , discord .DMChannel ):
401
436
return
402
437
destinations = await self .get_destinations (message .channel )
438
+ if not destinations :
439
+ return
403
440
for dst in destinations :
404
441
embed = MessageEmbed (message = message )
405
442
await dst .send (embed = embed )
0 commit comments