Skip to content

Commit

Permalink
stop using cursed default collection types
Browse files Browse the repository at this point in the history
deque drops messages when full, asyncio.Queue can never be
safely closed
  • Loading branch information
Tjstretchalot committed Jan 12, 2025
1 parent 74609f5 commit c495506
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 88 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "lonelypss"
version = "0.0.5"
version = "0.0.6"
description = "PubSub over HTTP"
readme = "README.md"
authors = [
Expand All @@ -22,7 +22,7 @@ classifiers = [
dependencies = [
"fastapi>=0.115",
"aiohttp>=3.11",
"lonelypsp>=0.0.24"
"lonelypsp>=0.0.26"
]
requires-python = ">=3.9"

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ colorama==0.4.6
fastapi==0.115.5
frozenlist==1.5.0
idna==3.10
lonelypsp==0.0.24
lonelypsp==0.0.26
multidict==6.1.0
mypy==1.13.0
mypy-extensions==1.0.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import tempfile
from typing import Union

from lonelypsp.util.bounded_deque import BoundedDequeFullError

from lonelypss.ws.handlers.open.check_result import CheckResult
from lonelypss.ws.handlers.open.senders.send_any import send_any
from lonelypss.ws.state import (
Expand All @@ -15,27 +17,34 @@
)


async def check_internal_message_task(state: StateOpen) -> CheckResult:
async def check_my_receiver_queue(state: StateOpen) -> CheckResult:
"""Makes progress using the result of the read task, if possible. Raises
an exception to indicate that we should begin the cleanup and shutdown
process
"""
if not state.internal_message_task.done():
if state.my_receiver.queue.empty():
return CheckResult.CONTINUE

result = state.internal_message_task.result()
state.internal_message_task = asyncio.create_task(state.my_receiver.queue.get())
result = state.my_receiver.queue.get_nowait()

if state.send_task is None and not state.unsent_messages:
state.send_task = asyncio.create_task(send_any(state, result))
return CheckResult.RESTART

try:
state.unsent_messages.ensure_space_for(1)
except BoundedDequeFullError:
if result.type == InternalMessageType.LARGE:
result.finished.set()
raise

if result.type != InternalMessageType.LARGE:
state.unsent_messages.append(result)
return CheckResult.RESTART

spooled = _spool_large_message(state, result)
state.unsent_messages.append(spooled)

return CheckResult.RESTART


Expand Down
131 changes: 67 additions & 64 deletions src/lonelypss/ws/handlers/open/handler.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,21 @@
import asyncio
from collections import deque
from typing import TYPE_CHECKING, Iterable, List, Optional, SupportsIndex, Union, cast
import sys
from typing import (
TYPE_CHECKING,
Any,
List,
Optional,
Union,
cast,
)

from lonelypsp.util.bounded_deque import BoundedDeque
from lonelypsp.util.cancel_and_check import cancel_and_check

from lonelypss.ws.handlers.open.check_background_tasks import check_background_tasks
from lonelypss.ws.handlers.open.check_compressors import check_compressors
from lonelypss.ws.handlers.open.check_internal_message_task import (
check_internal_message_task,
from lonelypss.ws.handlers.open.check_my_receiver_queue import (
check_my_receiver_queue,
)
from lonelypss.ws.handlers.open.check_process_task import check_process_task
from lonelypss.ws.handlers.open.check_read_task import check_read_task
Expand All @@ -20,6 +30,7 @@
from lonelypss.ws.state import (
CompressorState,
CompressorTrainingInfoType,
InternalMessageType,
SimplePendingSendPreFormatted,
State,
StateClosing,
Expand All @@ -29,6 +40,16 @@
WaitingInternalMessageType,
)

if sys.version_info < (3, 11):
from typing import NoReturn
from typing import NoReturn as Never

def assert_never(value: Never) -> NoReturn:
raise AssertionError(f"Unhandled type: {value!r}")

else:
from typing import assert_never


async def handle_open(state: State) -> State:
"""Makes some progress, waiting if necessary, and returning the new state. This
Expand All @@ -44,7 +65,7 @@ async def handle_open(state: State) -> State:
if await check_send_task(state) == CheckResult.RESTART:
return state

if await check_internal_message_task(state) == CheckResult.RESTART:
if await check_my_receiver_queue(state) == CheckResult.RESTART:
return state

if await check_read_task(state) == CheckResult.RESTART:
Expand All @@ -59,21 +80,34 @@ async def handle_open(state: State) -> State:
if await check_compressors(state) == CheckResult.RESTART:
return state

await asyncio.wait(
[
*([state.send_task] if state.send_task is not None else []),
state.internal_message_task,
state.read_task,
*([state.process_task] if state.process_task is not None else []),
*state.backgrounded,
*[
compressor.task
for compressor in state.compressors
if compressor.type == CompressorState.PREPARING
owned_tasks: List[asyncio.Task[Any]] = []
if state.my_receiver.queue.empty():
owned_tasks.append(
asyncio.create_task(state.my_receiver.queue.wait_not_empty())
)
try:
await asyncio.wait(
[
*([state.send_task] if state.send_task is not None else []),
*owned_tasks,
state.read_task,
*(
[state.process_task]
if state.process_task is not None
else []
),
*state.backgrounded,
*[
compressor.task
for compressor in state.compressors
if compressor.type == CompressorState.PREPARING
],
],
],
return_when=asyncio.FIRST_COMPLETED,
)
return_when=asyncio.FIRST_COMPLETED,
)
finally:
for task in owned_tasks:
await cancel_and_check(task)
return state
except NormalDisconnectException:
if state.send_task is not None:
Expand All @@ -84,13 +118,11 @@ async def handle_open(state: State) -> State:
asyncio.Task[None], asyncio.create_task(asyncio.Event().wait())
)
old_unsent = state.unsent_messages
state.unsent_messages = VoidingDeque()
state.unsent_messages = BoundedDeque(maxlen=0)

while old_unsent:
_cleanup(old_unsent.popleft())

state.internal_message_task.cancel()

if not _disconnected_receiver:
_disconnected_receiver = True
await _disconnect_receiver(state)
Expand Down Expand Up @@ -121,7 +153,6 @@ async def handle_open(state: State) -> State:
cleanup_exceptions.append(e2)

state.read_task.cancel()
state.internal_message_task.cancel()

if state.notify_stream_state is not None:
try:
Expand Down Expand Up @@ -198,6 +229,19 @@ async def _disconnect_receiver(state: StateOpen) -> None:
except BaseException as e:
excs.append(e)

for msg in state.my_receiver.queue.drain():
if msg.type == InternalMessageType.SMALL:
continue

if msg.type == InternalMessageType.LARGE:
msg.finished.set()
continue

if msg.type == InternalMessageType.MISSED:
continue

assert_never(msg)

if excs:
raise combine_multiple_exceptions(
"failed to properly disconnect receiver", excs
Expand All @@ -210,44 +254,3 @@ async def _disconnect_receiver(state: StateOpen) -> None:
def _cleanup(value: SendT) -> None:
if value.type == WaitingInternalMessageType.SPOOLED_LARGE:
value.stream.close()


class VoidingDeque(deque[SendT]):
def append(self, value: SendT, /) -> None:
_cleanup(value)

def appendleft(self, value: SendT, /) -> None:
_cleanup(value)

def insert(self, i: int, x: SendT, /) -> None:
_cleanup(x)

def extend(self, iterable: Iterable[SendT], /) -> None:
for v in iterable:
_cleanup(v)

def extendleft(self, iterable: Iterable[SendT], /) -> None:
for v in iterable:
_cleanup(v)

def __setitem__(
self,
key: Union[int, slice, SupportsIndex],
value: Union[SendT, Iterable[SendT]],
/,
) -> None:
if isinstance(key, slice):
for v in cast(Iterable[SendT], value):
_cleanup(v)
else:
_cleanup(cast(SendT, value))

def __iadd__(self, other: Iterable[SendT], /) -> "VoidingDeque":
for v in other:
_cleanup(v)
return self

def __add__(self, other: deque[SendT], /) -> "VoidingDeque":
for v in other:
_cleanup(v)
return self
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ async def process_notify_stream(state: StateOpen, message: S2B_NotifyStream) ->
) as decompressed_file:
hasher = hashlib.sha512()
pos = 0
body.seek(0)
with (
maybe_write_large_message_for_training(
state, first.decompressed_length
Expand Down
12 changes: 6 additions & 6 deletions src/lonelypss/ws/handlers/waiting_configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import secrets
import tempfile
import time
from collections import deque
from typing import TYPE_CHECKING, List, cast

import aiohttp
Expand All @@ -19,6 +18,8 @@
serialize_b2s_confirm_configure,
)
from lonelypsp.stateful.parser_helpers import parse_s2b_message_prefix
from lonelypsp.util.bounded_deque import BoundedDeque
from lonelypsp.util.drainable_asyncio_queue import DrainableAsyncioQueue

from lonelypss.util.websocket_message import WSMessageBytes
from lonelypss.ws.handlers.protocol import StateHandler
Expand Down Expand Up @@ -171,7 +172,6 @@ async def handle_waiting_configure(state: State) -> State:
broadcaster_counter=1,
subscriber_counter=-1,
read_task=make_websocket_read_task(state.websocket),
internal_message_task=asyncio.create_task(receiver.queue.get()),
notify_stream_state=None,
send_task=asyncio.create_task(
state.websocket.send_bytes(
Expand All @@ -186,14 +186,14 @@ async def handle_waiting_configure(state: State) -> State:
)
),
process_task=None,
unprocessed_messages=deque(
unprocessed_messages=BoundedDeque(
maxlen=state.broadcaster_config.websocket_max_unprocessed_receives
),
unsent_messages=deque(
unsent_messages=BoundedDeque(
maxlen=state.broadcaster_config.websocket_max_pending_sends
),
expecting_acks=asyncio.Queue(
maxsize=state.broadcaster_config.websocket_send_max_unacknowledged or 0
expecting_acks=DrainableAsyncioQueue(
max_size=state.broadcaster_config.websocket_send_max_unacknowledged or 0
),
backgrounded=set(),
)
Expand Down
4 changes: 3 additions & 1 deletion src/lonelypss/ws/simple_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import re
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type

from lonelypsp.util.drainable_asyncio_queue import DrainableAsyncioQueue

from lonelypss.util.sync_io import SyncReadableBytesIO
from lonelypss.ws.state import (
AsyncioWSReceiver,
Expand All @@ -19,7 +21,7 @@ def __init__(self) -> None:
self.glob_subscriptions: List[Tuple[re.Pattern, str]] = []
self.receiver_id: Optional[int] = None

self.queue: asyncio.Queue[InternalMessage] = asyncio.Queue()
self.queue: DrainableAsyncioQueue[InternalMessage] = DrainableAsyncioQueue()

def is_relevant(self, topic: bytes) -> bool:
if topic in self.exact_subscriptions:
Expand Down
Loading

0 comments on commit c495506

Please sign in to comment.