From 7b2f28127ff4200a4d009e4fe1b1868d856c24d8 Mon Sep 17 00:00:00 2001
From: Albert Tugushev
Date: Sat, 3 Jun 2023 22:47:36 +0800
Subject: [PATCH] Implement naive reader.set_max_in_flight()

---
 ansq/tcp/reader.py   | 159 +++++++++++++++++++++++++++++++++----------
 tests/test_reader.py |  59 ++++++++++++++++
 2 files changed, 183 insertions(+), 35 deletions(-)

diff --git a/ansq/tcp/reader.py b/ansq/tcp/reader.py
index 32bae46..1c26b37 100644
--- a/ansq/tcp/reader.py
+++ b/ansq/tcp/reader.py
@@ -6,6 +6,8 @@
     TYPE_CHECKING,
     Any,
     AsyncIterator,
+    Awaitable,
+    Callable,
     Dict,
     List,
     NamedTuple,
@@ -54,6 +56,12 @@ def __init__(
         self._channel = channel
         self._loop = loop or asyncio.get_event_loop()
         self._lookupd: Optional["Lookupd"] = None
+        self._max_in_flight = 1
+        self._rdy_state_distributor = RdyStateDistributor(
+            reader=self,
+            interval=1000,
+            loop=self._loop,
+        )
 
         # Common message queue for all connections
         self._message_queue: "asyncio.Queue[Optional[NSQMessage]]" = asyncio.Queue()
@@ -78,6 +86,8 @@ async def connect(self) -> None:
         Queries lookupd if specified.
         """
         await super().connect()
+        await self._rdy_state_distributor.distribute()
+        await self._rdy_state_distributor.start_redistributing()
 
         if self._lookupd:
             # Do first lookup manually
@@ -127,12 +137,8 @@ def message_queue(self) -> "asyncio.Queue[Optional['NSQMessage']]":
 
     @property
     def max_in_flight(self) -> int:
-        """Return 'max_in_flight' number.
-
-        Currently, it equals to number of current connections where every connection
-        has RDY=1.
-        """
-        return len(self._connections)
+        """Return 'max_in_flight' number."""
+        return self._max_in_flight
 
     async def set_max_in_flight(self, count: int) -> None:
         """Update 'max_in_flight' number.
@@ -141,7 +147,8 @@ async def set_max_in_flight(self, count: int) -> None:
         nsqd expects a response. It effects how RDY state is managed.
         For more detail see the doc: https://nsq.io/clients/building_client_libraries.html#rdy-state
         """
-        raise NotImplementedError("Update max_in_flight not implemented yet")
+        self._max_in_flight = count
+        await self._rdy_state_distributor.distribute()  # apply the new budget now
 
     async def connect_to_nsqd(self, host: str, port: int) -> "NSQConnection":
         """Connect, identify and subscribe to nsqd by given host and port."""
@@ -159,6 +166,8 @@ async def close(self) -> None:
         if self._lookupd is not None:
             await self._lookupd.close()
 
+        await self._rdy_state_distributor.stop_redistributing()
+
         await super().close()
 
 
@@ -175,8 +184,6 @@ def __init__(
         debug: bool = False,
     ):
         self._reader = reader
-        self._poll_interval = poll_interval / 1000
-        self._poll_jitter = poll_jitter
         self._loop = loop or asyncio.get_event_loop()
         self._query_lookupd_attempts = 0
         self._logger = get_logger(debug, "lookupd")
@@ -202,6 +209,13 @@ def __init__(
             NsqLookupd.from_address(address) for address in http_addresses
         ]
 
+        self._periodic_callback = PeriodicCallback(
+            callback=self.query_lookup,
+            interval=poll_interval,
+            jitter=poll_jitter,
+            loop=loop,
+        )
+
     async def query_lookup(self) -> None:
         """Query lookupd for topic producers and connect to them."""
         # Get lookupd connection in a round robin fashion way
@@ -223,35 +237,11 @@ async def query_lookup(self) -> None:
         for address in producer_addresses:
             await self._reader.connect_to_nsqd(address.host, address.port)
 
-    async def poll_lookup(self) -> NoReturn:
-        """Poll ``query_lookup()`` infinitely."""
-        # Add a delay to poll which helps to distribute evenly requests
-        # even if multiple readers restart at the same time.
-        delay = self._poll_interval * self._poll_jitter
-        await asyncio.sleep(random.random() * delay)
-
-        # Poll infinitely lookup
-        while True:
-            await asyncio.sleep(self._poll_interval)
-            await self.query_lookup()
-
     async def start_polling(self) -> None:
-        """Start polling lookupd."""
-        # Polling is already started
-        if self._poll_lookup_task is not None and not self._poll_lookup_task.done():
-            return
-
-        # Start polling task
-        self._poll_lookup_task = self._loop.create_task(self.poll_lookup())
+        await self._periodic_callback.start()
 
     async def stop_polling(self) -> None:
-        """Stop polling lookupd."""
-        if self._poll_lookup_task is None:
-            return
-
-        self._poll_lookup_task.cancel()
-        with contextlib.suppress(asyncio.CancelledError):
-            await self._poll_lookup_task
+        await self._periodic_callback.stop()
 
     async def close(self) -> None:
         """Close all lookupd connections and stop poll lookup task."""
@@ -341,3 +331,102 @@ async def create_reader(
     )
     await reader.connect()
     return reader
+
+
+class RdyStateDistributor:
+    """Spread the reader's max_in_flight over its open connections as RDY counts."""
+
+    def __init__(
+        self,
+        reader: Reader,
+        interval: float,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
+        self._reader = reader
+        self._periodic_callback = PeriodicCallback(
+            callback=self.distribute,
+            interval=interval,
+            loop=loop,
+        )
+
+    async def start_redistributing(self) -> None:
+        await self._periodic_callback.start()
+
+    async def stop_redistributing(self) -> None:
+        await self._periodic_callback.stop()
+
+    async def distribute(self) -> None:
+        """Distribute evenly RDY state across all connections."""
+
+        # TODO: https://nsq.io/clients/building_client_libraries.html#rdy-state.
+        #   - different flow for max_in_flight < num_conns
+        #   - starvation flow
+        #   - backoff
+
+        max_in_flight = self._reader.max_in_flight
+
+        open_conns = [conn for conn in self._reader.connections if conn.is_connected]
+        if not open_conns:
+            return
+
+        # disable RDY state for all connections
+        if max_in_flight == 0:
+            for conn in open_conns:
+                await conn.rdy(0)
+            return
+
+        rdy_per_conn = min(max(1, max_in_flight // len(open_conns)), max_in_flight)
+
+        # distribute evenly the rdy count to all connections
+        for conn in open_conns:
+            if max_in_flight <= 0:
+                await conn.rdy(0)
+                continue  # budget spent: every remaining connection must be zeroed
+
+            await conn.rdy(rdy_per_conn)
+            max_in_flight -= rdy_per_conn
+
+        # distribute the remaining rdy count to the last connection
+        if max_in_flight > 0:
+            await open_conns[-1].rdy(rdy_per_conn + max_in_flight)
+
+
+class PeriodicCallback:
+    """Run an awaitable callback every ``interval`` milliseconds until stopped."""
+
+    def __init__(
+        self,
+        callback: Callable[[], Awaitable[None]],
+        interval: float,
+        jitter: float = 0,
+        *,
+        loop: Optional[asyncio.AbstractEventLoop] = None,
+    ) -> None:
+        self._callback = callback
+        self._interval = interval / 1000  # ms -> s; true division keeps fractions
+        self._jitter = jitter
+        self._loop = loop or asyncio.get_event_loop()
+        self._task: Optional[asyncio.Task[None]] = None
+
+    async def start(self) -> None:
+        if self._task is not None and not self._task.done():  # already running
+            return
+
+        self._task = self._loop.create_task(self._run())
+
+    async def stop(self) -> None:
+        if self._task is None:
+            return
+
+        self._task.cancel()
+        with contextlib.suppress(asyncio.CancelledError):
+            await self._task
+
+    async def _run(self) -> NoReturn:
+        if self._jitter:
+            delay = self._interval * self._jitter
+            await asyncio.sleep(random.random() * delay)
+
+        while True:
+            await asyncio.sleep(self._interval)
+            await self._callback()
diff --git a/tests/test_reader.py b/tests/test_reader.py
index 57b33a3..9d654f5 100644
--- a/tests/test_reader.py
+++ b/tests/test_reader.py
@@ -79,6 +79,7 @@ async def test_read_from_multiple_tcp_addresses(nsqd, nsqd2):
         channel="bar",
         nsqd_tcp_addresses=[nsqd.tcp_address, nsqd2.tcp_address],
     )
+    await reader.set_max_in_flight(2)
 
     nsq1 = await open_connection(nsqd.host, nsqd.port)
     await nsq1.pub(topic="foo", message="test_message1")
@@ -97,3 +98,61 @@ async def test_read_from_multiple_tcp_addresses(nsqd, nsqd2):
     assert message.body == b"test_message2"
 
     await reader.close()
+
+
+async def test_set_max_in_flight(nsqd):
+    reader = await create_reader(topic="foo", channel="bar")
+
+    await reader.set_max_in_flight(7)
+
+    assert reader.max_in_flight == 7
+
+    await reader.close()
+
+
+@pytest.mark.parametrize(
+    "max_in_flight, expected_rdys",
+    (
+        (0, (0, 0)),
+        (1, (1, 0)),
+        (2, (1, 1)),
+        (3, (1, 2)),
+        (4, (2, 2)),
+        (5, (2, 3)),
+    ),
+)
+async def test_distribute_evenly_max_in_flight(
+    nsqd, nsqd2, max_in_flight, expected_rdys
+):
+    reader = await create_reader(
+        topic="foo",
+        channel="bar",
+        nsqd_tcp_addresses=[nsqd.tcp_address, nsqd2.tcp_address],
+    )
+
+    await reader.set_max_in_flight(max_in_flight)
+    assert get_rdys(reader) == expected_rdys
+
+    await reader.close()
+
+
+async def test_redistribute_max_in_flight_on_close_connection(nsqd, nsqd2, wait_for):
+    reader = await create_reader(
+        topic="foo",
+        channel="bar",
+        nsqd_tcp_addresses=[nsqd.tcp_address, nsqd2.tcp_address],
+    )
+
+    await reader.set_max_in_flight(5)
+    assert get_rdys(reader) == (2, 3)
+
+    await reader.connections[0].close()
+    await wait_for(lambda: get_rdys(reader) == (5,))
+
+    await reader.close()
+
+
+def get_rdys(reader) -> tuple[int, ...]:
+    return tuple(
+        conn.rdy_messages_count for conn in reader.connections if conn.is_connected
+    )