Skip to content

Commit beeef7d

Browse files
plun1331pre-commit-ci[bot]Lulalaby
authored
feat: subscriptions & related changes (#2564)
* feat: subscriptions & related changes * style(pre-commit): auto fixes from pre-commit.com hooks * feat: changelog * feat: changelog * style(pre-commit): auto fixes from pre-commit.com hooks * fix: move abc import to TYPE_CHECKING * style(pre-commit): auto fixes from pre-commit.com hooks * fix: circular import from Entitlement * docs: correct directives from notice to note * Update discord/enums.py Signed-off-by: Lala Sabathil <[email protected]> * despite what it looks like, this is a lazily written commit message * despite what it looks like, this is a lazily written commit message 2 * fix bugs * style(pre-commit): auto fixes from pre-commit.com hooks * Update CHANGELOG.md Signed-off-by: plun1331 <[email protected]> * Update iterators.py Signed-off-by: plun1331 <[email protected]> * apply review suggestions * Update monetization.py Signed-off-by: plun1331 <[email protected]> * Update monetization.py Signed-off-by: plun1331 <[email protected]> --------- Signed-off-by: Lala Sabathil <[email protected]> Signed-off-by: plun1331 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lala Sabathil <[email protected]>
1 parent 19721e2 commit beeef7d

File tree

11 files changed

+447
-54
lines changed

11 files changed

+447
-54
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ These changes are available on the `master` branch, but have not yet been releas
4747
([#2659](https://github.com/Pycord-Development/pycord/pull/2659))
4848
- Added `VoiceMessage` subclass of `File` to allow voice messages to be sent.
4949
([#2579](https://github.com/Pycord-Development/pycord/pull/2579))
50+
- Added new `Subscription` object and related methods/events.
51+
([#2564](https://github.com/Pycord-Development/pycord/pull/2564))
5052

5153
### Fixed
5254

@@ -104,6 +106,8 @@ These changes are available on the `master` branch, but have not yet been releas
104106
([#2176](https://github.com/Pycord-Development/pycord/pull/2176))
105107
- Updated `Guild.filesize_limit` to 10 MB instead of 25 MB following Discord's API
106108
changes. ([#2671](https://github.com/Pycord-Development/pycord/pull/2671))
109+
- `Entitlement.ends_at` can now be `None`.
110+
([#2564](https://github.com/Pycord-Development/pycord/pull/2564))
107111

108112
### Deprecated
109113

@@ -112,6 +116,11 @@ These changes are available on the `master` branch, but have not yet been releas
112116
- Deprecated `Emoji` in favor of `GuildEmoji`.
113117
([#2501](https://github.com/Pycord-Development/pycord/pull/2501))
114118

119+
### Fixed
120+
121+
- Fixed `AttributeError` when trying to consume a consumable entitlement.
122+
([#2564](https://github.com/Pycord-Development/pycord/pull/2564))
123+
115124
## [2.6.1] - 2024-09-15
116125

117126
### Fixed

discord/client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2087,7 +2087,7 @@ async def fetch_skus(self) -> list[SKU]:
20872087
The bot's SKUs.
20882088
"""
20892089
data = await self._connection.http.list_skus(self.application_id)
2090-
return [SKU(data=s) for s in data]
2090+
return [SKU(state=self._connection, data=s) for s in data]
20912091

20922092
def entitlements(
20932093
self,

discord/enums.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,6 +1055,14 @@ class PollLayoutType(Enum):
10551055
default = 1
10561056

10571057

1058+
class SubscriptionStatus(Enum):
1059+
"""The status of a subscription."""
1060+
1061+
active = 0
1062+
ending = 1
1063+
inactive = 2
1064+
1065+
10581066
T = TypeVar("T")
10591067

10601068

discord/http.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3081,6 +3081,43 @@ def delete_test_entitlement(
30813081
)
30823082
return self.request(r)
30833083

3084+
def list_sku_subscriptions(
3085+
self,
3086+
sku_id: Snowflake,
3087+
*,
3088+
before: Snowflake | None = None,
3089+
after: Snowflake | None = None,
3090+
limit: int = 50,
3091+
user_id: Snowflake | None = None,
3092+
) -> Response[list[monetization.Subscription]]:
3093+
params: dict[str, Any] = {}
3094+
if before is not None:
3095+
params["before"] = before
3096+
if after is not None:
3097+
params["after"] = after
3098+
if limit is not None:
3099+
params["limit"] = limit
3100+
if user_id is not None:
3101+
params["user_id"] = user_id
3102+
return self.request(
3103+
Route("GET", "/skus/{sku_id}/subscriptions", sku_id=sku_id),
3104+
params=params,
3105+
)
3106+
3107+
def get_subscription(
3108+
self,
3109+
sku_id: Snowflake,
3110+
subscription_id: Snowflake,
3111+
) -> Response[monetization.Subscription]:
3112+
return self.request(
3113+
Route(
3114+
"GET",
3115+
"/skus/{sku_id}/subscriptions/{subscription_id}",
3116+
sku_id=sku_id,
3117+
subscription_id=subscription_id,
3118+
)
3119+
)
3120+
30843121
# Onboarding
30853122

30863123
def get_onboarding(self, guild_id: Snowflake) -> Response[onboarding.Onboarding]:

discord/iterators.py

Lines changed: 112 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040

4141
from .audit_logs import AuditLogEntry
4242
from .errors import NoMoreItems
43-
from .monetization import Entitlement
4443
from .object import Object
4544
from .utils import maybe_coroutine, snowflake_time, time_snowflake
4645

@@ -52,19 +51,22 @@
5251
"MemberIterator",
5352
"ScheduledEventSubscribersIterator",
5453
"EntitlementIterator",
54+
"SubscriptionIterator",
5555
)
5656

5757
if TYPE_CHECKING:
5858
from .abc import Snowflake
5959
from .guild import BanEntry, Guild
6060
from .member import Member
6161
from .message import Message
62+
from .monetization import Entitlement, Subscription
6263
from .scheduled_events import ScheduledEvent
6364
from .threads import Thread
6465
from .types.audit_log import AuditLog as AuditLogPayload
6566
from .types.guild import Guild as GuildPayload
6667
from .types.message import Message as MessagePayload
6768
from .types.monetization import Entitlement as EntitlementPayload
69+
from .types.monetization import Subscription as SubscriptionPayload
6870
from .types.threads import Thread as ThreadPayload
6971
from .types.user import PartialUser as PartialUserPayload
7072
from .user import User
@@ -1031,6 +1033,11 @@ def _get_retrieve(self):
10311033
self.retrieve = r
10321034
return r > 0
10331035

1036+
def create_entitlement(self, data) -> Entitlement:
1037+
from .monetization import Entitlement
1038+
1039+
return Entitlement(data=data, state=self.state)
1040+
10341041
async def fill_entitlements(self):
10351042
if not self._get_retrieve():
10361043
return
@@ -1044,9 +1051,9 @@ async def fill_entitlements(self):
10441051
self.limit = 0 # terminate loop
10451052

10461053
for element in data:
1047-
await self.entitlements.put(Entitlement(data=element, state=self.state))
1054+
await self.entitlements.put(self.create_entitlement(element))
10481055

1049-
async def _retrieve_entitlements(self, retrieve) -> list[Entitlement]:
1056+
async def _retrieve_entitlements(self, retrieve) -> list[EntitlementPayload]:
10501057
"""Retrieve entitlements and update next parameters."""
10511058
raise NotImplementedError
10521059

@@ -1089,3 +1096,105 @@ async def _retrieve_entitlements_after_strategy(
10891096
self.limit -= retrieve
10901097
self.after = Object(id=int(data[-1]["id"]))
10911098
return data
1099+
1100+
1101+
class SubscriptionIterator(_AsyncIterator["Subscription"]):
1102+
def __init__(
1103+
self,
1104+
state,
1105+
sku_id: int,
1106+
limit: int = None,
1107+
before: datetime.datetime | None = None,
1108+
after: datetime.datetime | None = None,
1109+
user_id: int | None = None,
1110+
):
1111+
if isinstance(before, datetime.datetime):
1112+
before = Object(id=time_snowflake(before, high=False))
1113+
if isinstance(after, datetime.datetime):
1114+
after = Object(id=time_snowflake(after, high=True))
1115+
1116+
self.state = state
1117+
self.sku_id = sku_id
1118+
self.limit = limit
1119+
self.before = before
1120+
self.after = after
1121+
self.user_id = user_id
1122+
1123+
self._filter = None
1124+
1125+
self.get_subscriptions = state.http.list_sku_subscriptions
1126+
self.subscriptions = asyncio.Queue()
1127+
1128+
if self.before and self.after:
1129+
self._retrieve_subscriptions = self._retrieve_subscriptions_before_strategy
1130+
self._filter = lambda m: int(m["id"]) > self.after.id
1131+
elif self.after:
1132+
self._retrieve_subscriptions = self._retrieve_subscriptions_after_strategy
1133+
else:
1134+
self._retrieve_subscriptions = self._retrieve_subscriptions_before_strategy
1135+
1136+
async def next(self) -> Guild:
1137+
if self.subscriptions.empty():
1138+
await self.fill_subscriptions()
1139+
1140+
try:
1141+
return self.subscriptions.get_nowait()
1142+
except asyncio.QueueEmpty:
1143+
raise NoMoreItems()
1144+
1145+
def _get_retrieve(self):
1146+
l = self.limit
1147+
if l is None or l > 100:
1148+
r = 100
1149+
else:
1150+
r = l
1151+
self.retrieve = r
1152+
return r > 0
1153+
1154+
def create_subscription(self, data) -> Subscription:
1155+
from .monetization import Subscription
1156+
1157+
return Subscription(state=self.state, data=data)
1158+
1159+
async def fill_subscriptions(self):
1160+
if self._get_retrieve():
1161+
data = await self._retrieve_subscriptions(self.retrieve)
1162+
if self.limit is None or len(data) < 100:
1163+
self.limit = 0
1164+
1165+
if self._filter:
1166+
data = filter(self._filter, data)
1167+
1168+
for element in data:
1169+
await self.subscriptions.put(self.create_subscription(element))
1170+
1171+
async def _retrieve_subscriptions(self, retrieve) -> list[SubscriptionPayload]:
1172+
raise NotImplementedError
1173+
1174+
async def _retrieve_subscriptions_before_strategy(self, retrieve):
1175+
before = self.before.id if self.before else None
1176+
data: list[SubscriptionPayload] = await self.get_subscriptions(
1177+
self.sku_id,
1178+
limit=retrieve,
1179+
before=before,
1180+
user_id=self.user_id,
1181+
)
1182+
if len(data):
1183+
if self.limit is not None:
1184+
self.limit -= retrieve
1185+
self.before = Object(id=int(data[-1]["id"]))
1186+
return data
1187+
1188+
async def _retrieve_subscriptions_after_strategy(self, retrieve):
1189+
after = self.after.id if self.after else None
1190+
data: list[SubscriptionPayload] = await self.get_subscriptions(
1191+
self.sku_id,
1192+
limit=retrieve,
1193+
after=after,
1194+
user_id=self.user_id,
1195+
)
1196+
if len(data):
1197+
if self.limit is not None:
1198+
self.limit -= retrieve
1199+
self.after = Object(id=int(data[0]["id"]))
1200+
return data

0 commit comments

Comments
 (0)