Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement reader.set_max_in_flight() #101

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 124 additions & 35 deletions ansq/tcp/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
TYPE_CHECKING,
Any,
AsyncIterator,
Awaitable,
Callable,
Dict,
List,
NamedTuple,
Expand Down Expand Up @@ -54,6 +56,12 @@ def __init__(
self._channel = channel
self._loop = loop or asyncio.get_event_loop()
self._lookupd: Optional["Lookupd"] = None
self._max_in_flight = 1
self._rdy_state_distributor = RdyStateDistributor(
reader=self,
interval=1000,
loop=self._loop,
)

# Common message queue for all connections
self._message_queue: "asyncio.Queue[Optional[NSQMessage]]" = asyncio.Queue()
Expand All @@ -78,6 +86,8 @@ async def connect(self) -> None:
Queries lookupd if specified.
"""
await super().connect()
await self._rdy_state_distributor.distribute()
await self._rdy_state_distributor.start_redistributing()

if self._lookupd:
# Do first lookup manually
Expand Down Expand Up @@ -127,12 +137,8 @@ def message_queue(self) -> "asyncio.Queue[Optional['NSQMessage']]":

@property
def max_in_flight(self) -> int:
"""Return 'max_in_flight' number.

Currently, it equals to number of current connections where every connection
has RDY=1.
"""
return len(self._connections)
"""Return 'max_in_flight' number."""
return self._max_in_flight

async def set_max_in_flight(self, count: int) -> None:
"""Update 'max_in_flight' number.
Expand All @@ -141,7 +147,8 @@ async def set_max_in_flight(self, count: int) -> None:
nsqd expects a response. It affects how RDY state is managed. For more details
see the doc: https://nsq.io/clients/building_client_libraries.html#rdy-state
"""
raise NotImplementedError("Update max_in_flight not implemented yet")
self._max_in_flight = count
await self._rdy_state_distributor.distribute()

async def connect_to_nsqd(self, host: str, port: int) -> "NSQConnection":
"""Connect, identify and subscribe to nsqd by given host and port."""
Expand All @@ -159,6 +166,8 @@ async def close(self) -> None:
if self._lookupd is not None:
await self._lookupd.close()

await self._rdy_state_distributor.stop_redistributing()

await super().close()


Expand All @@ -175,8 +184,6 @@ def __init__(
debug: bool = False,
):
self._reader = reader
self._poll_interval = poll_interval / 1000
self._poll_jitter = poll_jitter
self._loop = loop or asyncio.get_event_loop()
self._query_lookupd_attempts = 0
self._logger = get_logger(debug, "lookupd")
Expand All @@ -202,6 +209,13 @@ def __init__(
NsqLookupd.from_address(address) for address in http_addresses
]

self._periodic_callback = PeriodicCallback(
callback=self.query_lookup,
interval=poll_interval,
jitter=poll_jitter,
loop=loop,
)

async def query_lookup(self) -> None:
"""Query lookupd for topic producers and connect to them."""
# Get a lookupd connection in a round-robin fashion
Expand All @@ -223,35 +237,11 @@ async def query_lookup(self) -> None:
for address in producer_addresses:
await self._reader.connect_to_nsqd(address.host, address.port)

async def poll_lookup(self) -> NoReturn:
"""Poll ``query_lookup()`` infinitely."""
# Add a delay to poll which helps to distribute evenly requests
# even if multiple readers restart at the same time.
delay = self._poll_interval * self._poll_jitter
await asyncio.sleep(random.random() * delay)

# Poll infinitely lookup
while True:
await asyncio.sleep(self._poll_interval)
await self.query_lookup()

async def start_polling(self) -> None:
"""Start polling lookupd."""
# Polling is already started
if self._poll_lookup_task is not None and not self._poll_lookup_task.done():
return

# Start polling task
self._poll_lookup_task = self._loop.create_task(self.poll_lookup())
await self._periodic_callback.start()

async def stop_polling(self) -> None:
"""Stop polling lookupd."""
if self._poll_lookup_task is None:
return

self._poll_lookup_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._poll_lookup_task
await self._periodic_callback.stop()

async def close(self) -> None:
"""Close all lookupd connections and stop poll lookup task."""
Expand Down Expand Up @@ -341,3 +331,102 @@ async def create_reader(
)
await reader.connect()
return reader


class RdyStateDistributor:
    """Manage the RDY state of every open connection owned by a reader.

    The reader's ``max_in_flight`` budget is split across its currently
    connected connections. ``distribute()`` can be called on demand, and a
    periodic callback re-runs it every ``interval`` so the split is refreshed
    when the set of open connections changes.
    """

    def __init__(
        self,
        reader: "Reader",
        interval: float,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Bind the distributor to ``reader``.

        :param reader: object exposing ``max_in_flight`` and ``connections``.
        :param interval: redistribution period, forwarded unchanged to
            ``PeriodicCallback``.
        :param loop: optional event loop forwarded to ``PeriodicCallback``.
        """
        self._reader = reader
        self._periodic_callback = PeriodicCallback(
            callback=self.distribute,
            interval=interval,
            loop=loop,
        )

    async def start_redistributing(self) -> None:
        """Start the periodic redistribution task."""
        await self._periodic_callback.start()

    async def stop_redistributing(self) -> None:
        """Stop the periodic redistribution task."""
        await self._periodic_callback.stop()

    async def distribute(self) -> None:
        """Distribute evenly RDY state across all open connections.

        Exactly one RDY command is sent to every open connection, so a
        connection that no longer gets a share is explicitly reset to 0
        instead of keeping a stale RDY count.
        """

        # TODO: https://nsq.io/clients/building_client_libraries.html#rdy-state.
        # - different flow for max_in_flight < num_conns
        # - starvation flow
        # - backoff

        open_conns = [conn for conn in self._reader.connections if conn.is_connected]
        if not open_conns:
            return

        rdy_counts = self._split_rdy(self._reader.max_in_flight, len(open_conns))
        for conn, rdy_count in zip(open_conns, rdy_counts):
            await conn.rdy(rdy_count)

    @staticmethod
    def _split_rdy(max_in_flight: int, num_conns: int) -> List[int]:
        """Return the RDY count for each of ``num_conns`` connections.

        The counts always sum to ``max(max_in_flight, 0)``:

        - a non-positive budget disables every connection;
        - a budget smaller than the number of connections gives RDY=1 to the
          first ``max_in_flight`` connections and RDY=0 to all the others;
        - otherwise every connection gets the integer share and the last
          connection additionally receives the remainder.
        """
        if max_in_flight <= 0:
            return [0] * num_conns

        share, remainder = divmod(max_in_flight, num_conns)
        if share == 0:
            # Not enough budget for everyone: the surplus connections must be
            # zeroed, otherwise the total in flight could exceed the limit.
            return [1] * max_in_flight + [0] * (num_conns - max_in_flight)

        counts = [share] * num_conns
        counts[-1] += remainder
        return counts


class PeriodicCallback:
    """Run an async callback repeatedly in a background task.

    The callback is awaited once per ``interval``. An optional random start
    delay (``jitter``) helps spread out instances started at the same time.
    """

    def __init__(
        self,
        callback: Callable[[], Awaitable[None]],
        interval: float,
        jitter: float = 0,
        *,
        loop: Optional[asyncio.AbstractEventLoop] = None,
    ) -> None:
        """Configure the callback; nothing is scheduled until ``start()``.

        :param callback: zero-argument coroutine function to await each period.
        :param interval: period between calls, in milliseconds.
        :param jitter: fraction of the interval used as the upper bound of a
            random delay before the first period; ``0`` disables the delay.
        :param loop: event loop used to create the task; defaults to
            ``asyncio.get_event_loop()``.
        """
        self._callback = callback
        # True division keeps sub-second precision. Floor division would turn
        # e.g. 500 ms into 0 s and make the loop spin without any delay.
        self._interval = interval / 1000
        self._jitter = jitter
        self._loop = loop or asyncio.get_event_loop()
        self._task: "Optional[asyncio.Task[None]]" = None

    async def start(self) -> None:
        """Start the background task unless one is already running."""
        # A finished task (cancelled or crashed) must not block a restart.
        if self._task is not None and not self._task.done():
            return

        self._task = self._loop.create_task(self._run())

    async def stop(self) -> None:
        """Cancel the background task, if any, and wait for it to finish."""
        # Forget the task first so that a later start() creates a fresh one.
        task, self._task = self._task, None
        if task is None:
            return

        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task

    async def _run(self) -> NoReturn:
        """Sleep and invoke the callback forever (until cancelled)."""
        if self._jitter:
            # Random initial delay in [0, interval * jitter).
            delay = self._interval * self._jitter
            await asyncio.sleep(random.random() * delay)

        while True:
            await asyncio.sleep(self._interval)
            await self._callback()
59 changes: 59 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ async def test_read_from_multiple_tcp_addresses(nsqd, nsqd2):
channel="bar",
nsqd_tcp_addresses=[nsqd.tcp_address, nsqd2.tcp_address],
)
await reader.set_max_in_flight(2)

nsq1 = await open_connection(nsqd.host, nsqd.port)
await nsq1.pub(topic="foo", message="test_message1")
Expand All @@ -97,3 +98,61 @@ async def test_read_from_multiple_tcp_addresses(nsqd, nsqd2):
assert message.body == b"test_message2"

await reader.close()


async def test_set_max_in_flight(nsqd):
    """``set_max_in_flight()`` must be reflected by the ``max_in_flight`` property."""
    reader = await create_reader(topic="foo", channel="bar")

    # Close the reader even when the assertion fails so that no connection or
    # background task leaks into the following tests.
    try:
        await reader.set_max_in_flight(7)

        assert reader.max_in_flight == 7
    finally:
        await reader.close()


@pytest.mark.parametrize(
    "max_in_flight, expected_rdys",
    (
        (0, (0, 0)),
        (1, (1, 0)),
        (2, (1, 1)),
        (3, (1, 2)),
        (4, (2, 2)),
        (5, (2, 3)),
    ),
)
async def test_distribute_evenly_max_in_flight(
    nsqd, nsqd2, max_in_flight, expected_rdys
):
    """Check the per-connection RDY counts for a reader with two nsqd addresses.

    ``expected_rdys`` lists the RDY count expected on each connected
    connection, in the order reported by ``get_rdys()``.
    """
    reader = await create_reader(
        topic="foo",
        channel="bar",
        nsqd_tcp_addresses=[nsqd.tcp_address, nsqd2.tcp_address],
    )

    # Close the reader even when the assertion fails so that no connection or
    # background task leaks into the following parametrized cases.
    try:
        await reader.set_max_in_flight(max_in_flight)
        assert get_rdys(reader) == expected_rdys
    finally:
        await reader.close()


async def test_redistribute_max_in_flight_on_close_connection(nsqd, nsqd2, wait_for):
    """After one connection closes, the remaining one must get the whole budget."""
    reader = await create_reader(
        topic="foo",
        channel="bar",
        nsqd_tcp_addresses=[nsqd.tcp_address, nsqd2.tcp_address],
    )

    # Close the reader even when an assertion or the wait fails so that no
    # connection or background task leaks into the following tests.
    try:
        await reader.set_max_in_flight(5)
        assert get_rdys(reader) == (2, 3)

        await reader.connections[0].close()
        # Wait until the reported RDY counts show a single connection with 5.
        await wait_for(lambda: get_rdys(reader) == (5,))
    finally:
        await reader.close()


def get_rdys(reader) -> tuple[int, ...]:
    """Return ``rdy_messages_count`` of every connected connection of ``reader``.

    Connections whose ``is_connected`` is falsy are skipped; the order follows
    ``reader.connections``.
    """
    rdy_counts = []
    for connection in reader.connections:
        if not connection.is_connected:
            continue
        rdy_counts.append(connection.rdy_messages_count)
    return tuple(rdy_counts)