diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index e6d6d709..a68a62b5 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -15,7 +15,6 @@ LocalProtocolError, RemoteProtocolError, WriteError, - map_exceptions, ) from .._models import Origin, Request, Response from .._synchronization import AsyncLock, AsyncShieldCancellation @@ -141,12 +140,14 @@ async def _send_request_headers(self, request: Request) -> None: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("write", None) - with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): + try: event = h11.Request( method=request.method, target=request.url.target, headers=request.headers, ) + except h11.LocalProtocolError as exc: + raise LocalProtocolError(exc) from exc await self._send_event(event, timeout=timeout) async def _send_request_body(self, request: Request) -> None: @@ -210,8 +211,10 @@ async def _receive_event( self, timeout: float | None = None ) -> h11.Event | type[h11.PAUSED]: while True: - with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): + try: event = self._h11_state.next_event() + except h11.RemoteProtocolError as exc: + raise RemoteProtocolError(exc) from exc if event is h11.NEED_DATA: data = await self._network_stream.read( diff --git a/httpcore/_backends/anyio.py b/httpcore/_backends/anyio.py index a140095e..0f9cee22 100644 --- a/httpcore/_backends/anyio.py +++ b/httpcore/_backends/anyio.py @@ -12,7 +12,6 @@ ReadTimeout, WriteError, WriteTimeout, - map_exceptions, ) from .._utils import is_socket_readable from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream @@ -23,31 +22,32 @@ def __init__(self, stream: anyio.abc.ByteStream) -> None: self._stream = stream async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - exc_map = { - TimeoutError: ReadTimeout, - anyio.BrokenResourceError: ReadError, - anyio.ClosedResourceError: ReadError, - anyio.EndOfStream: ReadError, - } - with map_exceptions(exc_map): + try: with anyio.fail_after(timeout): try: return await self._stream.receive(max_bytes=max_bytes) except anyio.EndOfStream: # pragma: nocover return b"" + except TimeoutError as exc: + raise ReadTimeout(exc) from exc + except ( + anyio.BrokenResourceError, + anyio.ClosedResourceError, + anyio.EndOfStream, + ) as exc: + raise ReadError(exc) from exc async def write(self, buffer: bytes, timeout: float | None = None) -> None: if not buffer: return - exc_map = { - TimeoutError: WriteTimeout, - anyio.BrokenResourceError: WriteError, - anyio.ClosedResourceError: WriteError, - } - with map_exceptions(exc_map): + try: with anyio.fail_after(timeout): await self._stream.send(item=buffer) + except TimeoutError as exc: + raise WriteTimeout(exc) from exc + except (anyio.BrokenResourceError, anyio.ClosedResourceError) as exc: + raise WriteError(exc) from exc async def aclose(self) -> None: await self._stream.aclose() @@ -58,13 +58,7 @@ async def start_tls( server_hostname: str | None = None, timeout: float | None = None, ) -> AsyncNetworkStream: - exc_map = { - TimeoutError: ConnectTimeout, - anyio.BrokenResourceError: ConnectError, - anyio.EndOfStream: ConnectError, - ssl.SSLError: ConnectError, - } - with map_exceptions(exc_map): + try: try: with anyio.fail_after(timeout): ssl_stream = await anyio.streams.tls.TLSStream.wrap( @@ -77,6 +71,10 @@ async def start_tls( except Exception as exc: # pragma: nocover await self.aclose() raise exc + except TimeoutError as exc: + raise ConnectTimeout(exc) from exc + except (anyio.BrokenResourceError, anyio.EndOfStream, ssl.SSLError) as exc: + raise ConnectError(exc) from exc return AnyIOStream(ssl_stream) def get_extra_info(self, info: str) -> typing.Any: @@ -105,12 +103,7 @@ async def connect_tcp( ) -> AsyncNetworkStream: # pragma: nocover if socket_options is None: socket_options = [] - exc_map = { - TimeoutError: ConnectTimeout, - OSError: ConnectError, - anyio.BrokenResourceError: ConnectError, - } - with map_exceptions(exc_map): + try: with anyio.fail_after(timeout): stream: anyio.abc.ByteStream = await anyio.connect_tcp( remote_host=host, @@ -120,6 +113,10 @@ async def connect_tcp( # By default TCP sockets opened in `asyncio` include TCP_NODELAY. for option in socket_options: stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + except TimeoutError as exc: + raise ConnectTimeout(exc) from exc + except (OSError, anyio.BrokenResourceError) as exc: + raise ConnectError(exc) from exc return AnyIOStream(stream) async def connect_unix_socket( @@ -130,16 +127,15 @@ async def connect_unix_socket( ) -> AsyncNetworkStream: # pragma: nocover if socket_options is None: socket_options = [] - exc_map = { - TimeoutError: ConnectTimeout, - OSError: ConnectError, - anyio.BrokenResourceError: ConnectError, - } - with map_exceptions(exc_map): + try: with anyio.fail_after(timeout): stream: anyio.abc.ByteStream = await anyio.connect_unix(path) for option in socket_options: stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + except TimeoutError as exc: + raise ConnectTimeout(exc) from exc + except (OSError, anyio.BrokenResourceError) as exc: + raise ConnectError(exc) from exc return AnyIOStream(stream) async def sleep(self, seconds: float) -> None: diff --git a/httpcore/_backends/sync.py b/httpcore/_backends/sync.py index 4018a09c..0f135e65 100644 --- a/httpcore/_backends/sync.py +++ b/httpcore/_backends/sync.py @@ -9,12 +9,10 @@ from .._exceptions import ( ConnectError, ConnectTimeout, - ExceptionMapping, ReadError, ReadTimeout, WriteError, WriteTimeout, - map_exceptions, ) from .._utils import is_socket_readable from .base import SOCKET_OPTION, NetworkBackend, NetworkStream @@ -77,20 +75,26 @@ def _perform_io( return ret def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} - with map_exceptions(exc_map): + try: self._sock.settimeout(timeout) return typing.cast( bytes, self._perform_io(functools.partial(self.ssl_obj.read, max_bytes)) ) + except socket.timeout as exc: + raise ReadTimeout(exc) from exc + except OSError as exc: + raise ReadError(exc) from exc def write(self, buffer: bytes, timeout: float | None = None) -> None: - exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} - with map_exceptions(exc_map): + try: self._sock.settimeout(timeout) while buffer: nsent = self._perform_io(functools.partial(self.ssl_obj.write, buffer)) buffer = buffer[nsent:] + except socket.timeout as exc: + raise WriteTimeout(exc) from exc + except OSError as exc: + raise WriteError(exc) from exc def close(self) -> None: self._sock.close() @@ -122,21 +126,27 @@ def __init__(self, sock: socket.socket) -> None: self._sock = sock def read(self, max_bytes: int, timeout: float | None = None) -> bytes: - exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError} - with map_exceptions(exc_map): + try: self._sock.settimeout(timeout) return self._sock.recv(max_bytes) + except socket.timeout as exc: + raise ReadTimeout(exc) from exc + except OSError as exc: + raise ReadError(exc) from exc def write(self, buffer: bytes, timeout: float | None = None) -> None: if not buffer: return - exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError} - with map_exceptions(exc_map): + try: while buffer: self._sock.settimeout(timeout) n = self._sock.send(buffer) buffer = buffer[n:] + except socket.timeout as exc: + raise WriteTimeout(exc) from exc + except OSError as exc: + raise WriteError(exc) from exc def close(self) -> None: self._sock.close() @@ -147,11 +157,7 @@ def start_tls( server_hostname: str | None = None, timeout: float | None = None, ) -> NetworkStream: - exc_map: ExceptionMapping = { - socket.timeout: ConnectTimeout, - OSError: ConnectError, - } - with map_exceptions(exc_map): + try: try: if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover # If the underlying socket has already been upgraded @@ -168,6 +174,10 @@ def start_tls( except Exception as exc: # pragma: nocover self.close() raise exc + except socket.timeout as exc: + raise ConnectTimeout(exc) from exc + except OSError as exc: + raise ConnectError(exc) from exc return SyncStream(sock) def get_extra_info(self, info: str) -> typing.Any: @@ -199,12 +209,8 @@ def connect_tcp( socket_options = [] # pragma: no cover address = (host, port) source_address = None if local_address is None else (local_address, 0) - exc_map: ExceptionMapping = { - socket.timeout: ConnectTimeout, - OSError: ConnectError, - } - with map_exceptions(exc_map): + try: sock = socket.create_connection( address, timeout, @@ -213,6 +219,10 @@ def connect_tcp( for option in socket_options: sock.setsockopt(*option) # pragma: no cover sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + except socket.timeout as exc: + raise ConnectTimeout(exc) from exc + except OSError as exc: + raise ConnectError(exc) from exc return SyncStream(sock) def connect_unix_socket( @@ -228,14 +238,14 @@ def connect_unix_socket( if socket_options is None: socket_options = [] - exc_map: ExceptionMapping = { - socket.timeout: ConnectTimeout, - OSError: ConnectError, - } - with map_exceptions(exc_map): + try: sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) for option in socket_options: sock.setsockopt(*option) sock.settimeout(timeout) sock.connect(path) + except socket.timeout as exc: + raise ConnectTimeout(exc) from exc + except OSError as exc: + raise ConnectError(exc) from exc return SyncStream(sock) diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py index 6f53f5f2..88c72563 100644 --- a/httpcore/_backends/trio.py +++ b/httpcore/_backends/trio.py @@ -8,12 +8,10 @@ from .._exceptions import ( ConnectError, ConnectTimeout, - ExceptionMapping, ReadError, ReadTimeout, WriteError, WriteTimeout, - map_exceptions, ) from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream @@ -24,29 +22,27 @@ def __init__(self, stream: trio.abc.Stream) -> None: async def read(self, max_bytes: int, timeout: float | None = None) -> bytes: timeout_or_inf = float("inf") if timeout is None else timeout - exc_map: ExceptionMapping = { - trio.TooSlowError: ReadTimeout, - trio.BrokenResourceError: ReadError, - trio.ClosedResourceError: ReadError, - } - with map_exceptions(exc_map): + try: with trio.fail_after(timeout_or_inf): data: bytes = await self._stream.receive_some(max_bytes=max_bytes) return data + except trio.TooSlowError as exc: + raise ReadTimeout(exc) from exc + except (trio.BrokenResourceError, trio.ClosedResourceError) as exc: + raise ReadError(exc) from exc async def write(self, buffer: bytes, timeout: float | None = None) -> None: if not buffer: return timeout_or_inf = float("inf") if timeout is None else timeout - exc_map: ExceptionMapping = { - trio.TooSlowError: WriteTimeout, - trio.BrokenResourceError: WriteError, - trio.ClosedResourceError: WriteError, - } - with map_exceptions(exc_map): + try: with trio.fail_after(timeout_or_inf): await self._stream.send_all(data=buffer) + except trio.TooSlowError as exc: + raise WriteTimeout(exc) from exc + except (trio.BrokenResourceError, trio.ClosedResourceError) as exc: + raise WriteError(exc) from exc async def aclose(self) -> None: await self._stream.aclose() @@ -58,10 +54,6 @@ async def start_tls( timeout: float | None = None, ) -> AsyncNetworkStream: timeout_or_inf = float("inf") if timeout is None else timeout - exc_map: ExceptionMapping = { - trio.TooSlowError: ConnectTimeout, - trio.BrokenResourceError: ConnectError, - } ssl_stream = trio.SSLStream( self._stream, ssl_context=ssl_context, @@ -69,13 +61,17 @@ async def start_tls( https_compatible=True, server_side=False, ) - with map_exceptions(exc_map): + try: try: with trio.fail_after(timeout_or_inf): await ssl_stream.do_handshake() except Exception as exc: # pragma: nocover await self.aclose() raise exc + except trio.TooSlowError as exc: + raise ConnectTimeout(exc) from exc + except trio.BrokenResourceError as exc: + raise ConnectError(exc) from exc return TrioStream(ssl_stream) def get_extra_info(self, info: str) -> typing.Any: @@ -120,18 +116,17 @@ async def connect_tcp( if socket_options is None: socket_options = [] # pragma: no cover timeout_or_inf = float("inf") if timeout is None else timeout - exc_map: ExceptionMapping = { - trio.TooSlowError: ConnectTimeout, - trio.BrokenResourceError: ConnectError, - OSError: ConnectError, - } - with map_exceptions(exc_map): + try: with trio.fail_after(timeout_or_inf): stream: trio.abc.Stream = await trio.open_tcp_stream( host=host, port=port, local_address=local_address ) for option in socket_options: stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + except trio.TooSlowError as exc: + raise ConnectTimeout(exc) from exc + except (trio.BrokenResourceError, OSError) as exc: + raise ConnectError(exc) from exc return TrioStream(stream) async def connect_unix_socket( @@ -143,16 +138,15 @@ async def connect_unix_socket( if socket_options is None: socket_options = [] timeout_or_inf = float("inf") if timeout is None else timeout - exc_map: ExceptionMapping = { - trio.TooSlowError: ConnectTimeout, - trio.BrokenResourceError: ConnectError, - OSError: ConnectError, - } - with map_exceptions(exc_map): + try: with trio.fail_after(timeout_or_inf): stream: trio.abc.Stream = await trio.open_unix_socket(path) for option in socket_options: stream.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover + except trio.TooSlowError as exc: + raise ConnectTimeout(exc) from exc + except (trio.BrokenResourceError, OSError) as exc: + raise ConnectError(exc) from exc return TrioStream(stream) async def sleep(self, seconds: float) -> None: diff --git a/httpcore/_exceptions.py b/httpcore/_exceptions.py index bc28d44f..339c1e1b 100644 --- a/httpcore/_exceptions.py +++ b/httpcore/_exceptions.py @@ -1,20 +1,3 @@ -import contextlib -import typing - -ExceptionMapping = typing.Mapping[typing.Type[Exception], typing.Type[Exception]] - - -@contextlib.contextmanager -def map_exceptions(map: ExceptionMapping) -> typing.Iterator[None]: - try: - yield - except Exception as exc: # noqa: PIE786 - for from_exc, to_exc in map.items(): - if isinstance(exc, from_exc): - raise to_exc(exc) from exc - raise # pragma: nocover - - class ConnectionNotAvailable(Exception): pass diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index ebd3a974..363cb9f3 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -15,7 +15,6 @@ LocalProtocolError, RemoteProtocolError, WriteError, - map_exceptions, ) from .._models import Origin, Request, Response from .._synchronization import Lock, ShieldCancellation @@ -141,12 +140,14 @@ def _send_request_headers(self, request: Request) -> None: timeouts = request.extensions.get("timeout", {}) timeout = timeouts.get("write", None) - with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): + try: event = h11.Request( method=request.method, target=request.url.target, headers=request.headers, ) + except h11.LocalProtocolError as exc: + raise LocalProtocolError(exc) from exc self._send_event(event, timeout=timeout) def _send_request_body(self, request: Request) -> None: @@ -210,8 +211,10 @@ def _receive_event( self, timeout: float | None = None ) -> h11.Event | type[h11.PAUSED]: while True: - with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): + try: event = self._h11_state.next_event() + except h11.RemoteProtocolError as exc: + raise RemoteProtocolError(exc) from exc if event is h11.NEED_DATA: data = self._network_stream.read( diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py index 2ecc9e9c..6e4345ee 100644 --- a/httpcore/_synchronization.py +++ b/httpcore/_synchronization.py @@ -3,7 +3,7 @@ import threading import types -from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions +from ._exceptions import PoolTimeout # Our async synchronization primatives use either 'anyio' or 'trio' depending # on if they're running under asyncio or trio. @@ -139,16 +139,18 @@ async def wait(self, timeout: float | None = None) -> None: self.setup() if self._backend == "trio": - trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout} timeout_or_inf = float("inf") if timeout is None else timeout - with map_exceptions(trio_exc_map): + try: with trio.fail_after(timeout_or_inf): await self._trio_event.wait() + except trio.TooSlowError as exc: + raise PoolTimeout(exc) from exc elif self._backend == "asyncio": - anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout} - with map_exceptions(anyio_exc_map): + try: with anyio.fail_after(timeout): await self._anyio_event.wait() + except TimeoutError as exc: + raise PoolTimeout(exc) from exc class AsyncSemaphore: