@@ -254,9 +254,9 @@ class PyMongoBaseProtocol(Protocol):
254
254
def __init__ (self , timeout : Optional [float ] = None ):
255
255
self .transport : Transport = None # type: ignore[assignment]
256
256
self ._timeout = timeout
257
- self ._closing_error = asyncio .get_running_loop ().create_future ()
258
257
self ._closed = asyncio .get_running_loop ().create_future ()
259
258
self ._connection_lost = False
259
+ self ._closing_exception = None
260
260
261
261
def settimeout (self , timeout : float | None ) -> None :
262
262
self ._timeout = timeout
@@ -270,11 +270,11 @@ def close(self, exc: Optional[Exception] = None) -> None:
270
270
self .transport .abort ()
271
271
self ._resolve_pending (exc )
272
272
self ._connection_lost = True
273
+ self ._closing_exception = exc
273
274
274
275
def connection_lost (self , exc : Optional [Exception ] = None ) -> None :
275
- if exc is not None and not self ._closing_error .done ():
276
- self ._closing_error .set_exception (exc )
277
276
self ._resolve_pending (exc )
277
+ self ._closing_exception = exc
278
278
if not self ._closed .done ():
279
279
self ._closed .set_result (None )
280
280
@@ -325,7 +325,6 @@ def connection_made(self, transport: BaseTransport) -> None:
325
325
"""
326
326
self .transport = transport # type: ignore[assignment]
327
327
self .transport .set_write_buffer_limits (MAX_MESSAGE_SIZE , MAX_MESSAGE_SIZE )
328
- super ().connection_made (self )
329
328
330
329
async def read (self , request_id : Optional [int ], max_message_size : int ) -> tuple [bytes , int ]:
331
330
"""Read a single MongoDB Wire Protocol message from this connection."""
@@ -339,8 +338,6 @@ async def read(self, request_id: Optional[int], max_message_size: int) -> tuple[
339
338
if self ._done_messages :
340
339
message = await self ._done_messages .popleft ()
341
340
else :
342
- if self .transport and self .transport .is_closing ():
343
- return await self ._closing_error
344
341
read_waiter = asyncio .get_running_loop ().create_future ()
345
342
self ._pending_messages .append (read_waiter )
346
343
try :
@@ -478,6 +475,7 @@ def _resolve_pending(self, exc: Optional[Exception] = None) -> None:
478
475
else :
479
476
msg .set_exception (exc )
480
477
self ._done_messages .append (msg )
478
+ self ._pending_messages .clear ()
481
479
482
480
483
481
class PyMongoKMSProtocol (PyMongoBaseProtocol ):
@@ -493,7 +491,6 @@ def connection_made(self, transport: BaseTransport) -> None:
493
491
The transport argument is the transport representing the write side of the connection.
494
492
"""
495
493
self .transport = transport # type: ignore[assignment]
496
- super ().connection_made (self )
497
494
498
495
def data_received (self , data : bytes ) -> None :
499
496
if self ._connection_lost :
0 commit comments