40
40
41
41
from .audit_logs import AuditLogEntry
42
42
from .errors import NoMoreItems
43
- from .monetization import Entitlement
44
43
from .object import Object
45
44
from .utils import maybe_coroutine , snowflake_time , time_snowflake
46
45
52
51
"MemberIterator" ,
53
52
"ScheduledEventSubscribersIterator" ,
54
53
"EntitlementIterator" ,
54
+ "SubscriptionIterator" ,
55
55
)
56
56
57
57
if TYPE_CHECKING :
58
58
from .abc import Snowflake
59
59
from .guild import BanEntry , Guild
60
60
from .member import Member
61
61
from .message import Message
62
+ from .monetization import Entitlement , Subscription
62
63
from .scheduled_events import ScheduledEvent
63
64
from .threads import Thread
64
65
from .types .audit_log import AuditLog as AuditLogPayload
65
66
from .types .guild import Guild as GuildPayload
66
67
from .types .message import Message as MessagePayload
67
68
from .types .monetization import Entitlement as EntitlementPayload
69
+ from .types .monetization import Subscription as SubscriptionPayload
68
70
from .types .threads import Thread as ThreadPayload
69
71
from .types .user import PartialUser as PartialUserPayload
70
72
from .user import User
@@ -1031,6 +1033,11 @@ def _get_retrieve(self):
1031
1033
self .retrieve = r
1032
1034
return r > 0
1033
1035
1036
+ def create_entitlement (self , data ) -> Entitlement :
1037
+ from .monetization import Entitlement
1038
+
1039
+ return Entitlement (data = data , state = self .state )
1040
+
1034
1041
async def fill_entitlements (self ):
1035
1042
if not self ._get_retrieve ():
1036
1043
return
@@ -1044,9 +1051,9 @@ async def fill_entitlements(self):
1044
1051
self .limit = 0 # terminate loop
1045
1052
1046
1053
for element in data :
1047
- await self .entitlements .put (Entitlement ( data = element , state = self .state ))
1054
+ await self .entitlements .put (self .create_entitlement ( element ))
1048
1055
1049
- async def _retrieve_entitlements (self , retrieve ) -> list [Entitlement ]:
1056
+ async def _retrieve_entitlements (self , retrieve ) -> list [EntitlementPayload ]:
1050
1057
"""Retrieve entitlements and update next parameters."""
1051
1058
raise NotImplementedError
1052
1059
@@ -1089,3 +1096,105 @@ async def _retrieve_entitlements_after_strategy(
1089
1096
self .limit -= retrieve
1090
1097
self .after = Object (id = int (data [- 1 ]["id" ]))
1091
1098
return data
1099
+
1100
+
1101
+ class SubscriptionIterator (_AsyncIterator ["Subscription" ]):
1102
+ def __init__ (
1103
+ self ,
1104
+ state ,
1105
+ sku_id : int ,
1106
+ limit : int = None ,
1107
+ before : datetime .datetime | None = None ,
1108
+ after : datetime .datetime | None = None ,
1109
+ user_id : int | None = None ,
1110
+ ):
1111
+ if isinstance (before , datetime .datetime ):
1112
+ before = Object (id = time_snowflake (before , high = False ))
1113
+ if isinstance (after , datetime .datetime ):
1114
+ after = Object (id = time_snowflake (after , high = True ))
1115
+
1116
+ self .state = state
1117
+ self .sku_id = sku_id
1118
+ self .limit = limit
1119
+ self .before = before
1120
+ self .after = after
1121
+ self .user_id = user_id
1122
+
1123
+ self ._filter = None
1124
+
1125
+ self .get_subscriptions = state .http .list_sku_subscriptions
1126
+ self .subscriptions = asyncio .Queue ()
1127
+
1128
+ if self .before and self .after :
1129
+ self ._retrieve_subscriptions = self ._retrieve_subscriptions_before_strategy
1130
+ self ._filter = lambda m : int (m ["id" ]) > self .after .id
1131
+ elif self .after :
1132
+ self ._retrieve_subscriptions = self ._retrieve_subscriptions_after_strategy
1133
+ else :
1134
+ self ._retrieve_subscriptions = self ._retrieve_subscriptions_before_strategy
1135
+
1136
+ async def next (self ) -> Guild :
1137
+ if self .subscriptions .empty ():
1138
+ await self .fill_subscriptions ()
1139
+
1140
+ try :
1141
+ return self .subscriptions .get_nowait ()
1142
+ except asyncio .QueueEmpty :
1143
+ raise NoMoreItems ()
1144
+
1145
+ def _get_retrieve (self ):
1146
+ l = self .limit
1147
+ if l is None or l > 100 :
1148
+ r = 100
1149
+ else :
1150
+ r = l
1151
+ self .retrieve = r
1152
+ return r > 0
1153
+
1154
+ def create_subscription (self , data ) -> Subscription :
1155
+ from .monetization import Subscription
1156
+
1157
+ return Subscription (state = self .state , data = data )
1158
+
1159
+ async def fill_subscriptions (self ):
1160
+ if self ._get_retrieve ():
1161
+ data = await self ._retrieve_subscriptions (self .retrieve )
1162
+ if self .limit is None or len (data ) < 100 :
1163
+ self .limit = 0
1164
+
1165
+ if self ._filter :
1166
+ data = filter (self ._filter , data )
1167
+
1168
+ for element in data :
1169
+ await self .subscriptions .put (self .create_subscription (element ))
1170
+
1171
+ async def _retrieve_subscriptions (self , retrieve ) -> list [SubscriptionPayload ]:
1172
+ raise NotImplementedError
1173
+
1174
+ async def _retrieve_subscriptions_before_strategy (self , retrieve ):
1175
+ before = self .before .id if self .before else None
1176
+ data : list [SubscriptionPayload ] = await self .get_subscriptions (
1177
+ self .sku_id ,
1178
+ limit = retrieve ,
1179
+ before = before ,
1180
+ user_id = self .user_id ,
1181
+ )
1182
+ if len (data ):
1183
+ if self .limit is not None :
1184
+ self .limit -= retrieve
1185
+ self .before = Object (id = int (data [- 1 ]["id" ]))
1186
+ return data
1187
+
1188
+ async def _retrieve_subscriptions_after_strategy (self , retrieve ):
1189
+ after = self .after .id if self .after else None
1190
+ data : list [SubscriptionPayload ] = await self .get_subscriptions (
1191
+ self .sku_id ,
1192
+ limit = retrieve ,
1193
+ after = after ,
1194
+ user_id = self .user_id ,
1195
+ )
1196
+ if len (data ):
1197
+ if self .limit is not None :
1198
+ self .limit -= retrieve
1199
+ self .after = Object (id = int (data [0 ]["id" ]))
1200
+ return data
0 commit comments