Skip to content

Commit

Permalink
Fix available connection calculation (#9623)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Nov 1, 2024
1 parent 9f9bab2 commit 898aa28
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 5 deletions.
16 changes: 11 additions & 5 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,14 +475,20 @@ def _available_connections(self, key: "ConnectionKey") -> int:
available.
"""
# check total available connections
if self._limit and (available := self._limit - len(self._acquired)) <= 0:
return available
# If there are no limits, this will always return 1
total_remain = 1

if self._limit and (total_remain := self._limit - len(self._acquired)) <= 0:
return total_remain

# check limit per host
if self._limit_per_host and key in self._acquired_per_host:
return self._limit_per_host - len(self._acquired_per_host[key])
if host_remain := self._limit_per_host:
if acquired := self._acquired_per_host.get(key):
host_remain -= len(acquired)
if total_remain > host_remain:
return host_remain

return 1
return total_remain

async def connect(
self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout"
Expand Down
80 changes: 80 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def key2() -> ConnectionKey:
return ConnectionKey("localhost", 80, False, True, None, None, None)


@pytest.fixture
def other_host_key2() -> ConnectionKey:
# Connection key
return ConnectionKey("otherhost", 80, False, True, None, None, None)


@pytest.fixture
def ssl_key() -> ConnectionKey:
# Connection key
Expand Down Expand Up @@ -3390,3 +3396,77 @@ def test_default_ssl_context_creation_without_ssl() -> None:
with mock.patch.object(connector_module, "ssl", None):
assert connector_module._make_ssl_context(False) is None
assert connector_module._make_ssl_context(True) is None


async def test_available_connections_with_limit_per_host(
key: ConnectionKey, other_host_key2: ConnectionKey
) -> None:
"""Verify expected values based on active connections with host limit."""
conn = aiohttp.BaseConnector(limit=3, limit_per_host=2)
assert conn._available_connections(key) == 2
assert conn._available_connections(other_host_key2) == 2
proto1 = create_mocked_conn()
connection1 = conn._acquired_connection(proto1, key)
assert conn._available_connections(key) == 1
assert conn._available_connections(other_host_key2) == 2
proto2 = create_mocked_conn()
connection2 = conn._acquired_connection(proto2, key)
assert conn._available_connections(key) == 0
assert conn._available_connections(other_host_key2) == 1
connection1.close()
assert conn._available_connections(key) == 1
assert conn._available_connections(other_host_key2) == 2
connection2.close()
other_proto1 = create_mocked_conn()
other_connection1 = conn._acquired_connection(other_proto1, other_host_key2)
assert conn._available_connections(key) == 2
assert conn._available_connections(other_host_key2) == 1
other_connection1.close()
assert conn._available_connections(key) == 2
assert conn._available_connections(other_host_key2) == 2


@pytest.mark.parametrize("limit_per_host", [0, 10])
async def test_available_connections_without_limit_per_host(
key: ConnectionKey, other_host_key2: ConnectionKey, limit_per_host: int
) -> None:
"""Verify expected values based on active connections with higher host limit."""
conn = aiohttp.BaseConnector(limit=3, limit_per_host=limit_per_host)
assert conn._available_connections(key) == 3
assert conn._available_connections(other_host_key2) == 3
proto1 = create_mocked_conn()
connection1 = conn._acquired_connection(proto1, key)
assert conn._available_connections(key) == 2
assert conn._available_connections(other_host_key2) == 2
proto2 = create_mocked_conn()
connection2 = conn._acquired_connection(proto2, key)
assert conn._available_connections(key) == 1
assert conn._available_connections(other_host_key2) == 1
connection1.close()
assert conn._available_connections(key) == 2
assert conn._available_connections(other_host_key2) == 2
connection2.close()
other_proto1 = create_mocked_conn()
other_connection1 = conn._acquired_connection(other_proto1, other_host_key2)
assert conn._available_connections(key) == 2
assert conn._available_connections(other_host_key2) == 2
other_connection1.close()
assert conn._available_connections(key) == 3
assert conn._available_connections(other_host_key2) == 3


async def test_available_connections_no_limits(
key: ConnectionKey, other_host_key2: ConnectionKey
) -> None:
"""Verify expected values based on active connections with no limits."""
# No limits is a special case where available connections should always be 1.
conn = aiohttp.BaseConnector(limit=0, limit_per_host=0)
assert conn._available_connections(key) == 1
assert conn._available_connections(other_host_key2) == 1
proto1 = create_mocked_conn()
connection1 = conn._acquired_connection(proto1, key)
assert conn._available_connections(key) == 1
assert conn._available_connections(other_host_key2) == 1
connection1.close()
assert conn._available_connections(key) == 1
assert conn._available_connections(other_host_key2) == 1

0 comments on commit 898aa28

Please sign in to comment.