Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions httpcore/_async/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
LocalProtocolError,
RemoteProtocolError,
WriteError,
map_exceptions,
)
from .._models import Origin, Request, Response
from .._synchronization import AsyncLock, AsyncShieldCancellation
Expand Down Expand Up @@ -141,12 +140,14 @@ async def _send_request_headers(self, request: Request) -> None:
timeouts = request.extensions.get("timeout", {})
timeout = timeouts.get("write", None)

with map_exceptions({h11.LocalProtocolError: LocalProtocolError}):
try:
event = h11.Request(
method=request.method,
target=request.url.target,
headers=request.headers,
)
except h11.LocalProtocolError as exc:
raise LocalProtocolError(exc) from exc
await self._send_event(event, timeout=timeout)

async def _send_request_body(self, request: Request) -> None:
Expand Down Expand Up @@ -210,8 +211,10 @@ async def _receive_event(
self, timeout: float | None = None
) -> h11.Event | type[h11.PAUSED]:
while True:
with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}):
try:
event = self._h11_state.next_event()
except h11.RemoteProtocolError as exc:
raise RemoteProtocolError(exc) from exc

if event is h11.NEED_DATA:
data = await self._network_stream.read(
Expand Down
62 changes: 29 additions & 33 deletions httpcore/_backends/anyio.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, AsyncNetworkBackend, AsyncNetworkStream
Expand All @@ -23,31 +22,32 @@ def __init__(self, stream: anyio.abc.ByteStream) -> None:
self._stream = stream

async def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
exc_map = {
TimeoutError: ReadTimeout,
anyio.BrokenResourceError: ReadError,
anyio.ClosedResourceError: ReadError,
anyio.EndOfStream: ReadError,
}
with map_exceptions(exc_map):
try:
with anyio.fail_after(timeout):
try:
return await self._stream.receive(max_bytes=max_bytes)
except anyio.EndOfStream: # pragma: nocover
return b""
except TimeoutError as exc:
raise ReadTimeout(exc) from exc
except (
anyio.BrokenResourceError,
anyio.ClosedResourceError,
anyio.EndOfStream,
) as exc:
raise ReadError(exc) from exc

async def write(self, buffer: bytes, timeout: float | None = None) -> None:
if not buffer:
return

exc_map = {
TimeoutError: WriteTimeout,
anyio.BrokenResourceError: WriteError,
anyio.ClosedResourceError: WriteError,
}
with map_exceptions(exc_map):
try:
with anyio.fail_after(timeout):
await self._stream.send(item=buffer)
except TimeoutError as exc:
raise WriteTimeout(exc) from exc
except (anyio.BrokenResourceError, anyio.ClosedResourceError) as exc:
raise WriteError(exc) from exc

async def aclose(self) -> None:
await self._stream.aclose()
Expand All @@ -58,13 +58,7 @@ async def start_tls(
server_hostname: str | None = None,
timeout: float | None = None,
) -> AsyncNetworkStream:
exc_map = {
TimeoutError: ConnectTimeout,
anyio.BrokenResourceError: ConnectError,
anyio.EndOfStream: ConnectError,
ssl.SSLError: ConnectError,
}
with map_exceptions(exc_map):
try:
try:
with anyio.fail_after(timeout):
ssl_stream = await anyio.streams.tls.TLSStream.wrap(
Expand All @@ -77,6 +71,10 @@ async def start_tls(
except Exception as exc: # pragma: nocover
await self.aclose()
raise exc
except TimeoutError as exc:
raise ConnectTimeout(exc) from exc
except (anyio.BrokenResourceError, anyio.EndOfStream, ssl.SSLError) as exc:
raise ConnectError(exc) from exc
return AnyIOStream(ssl_stream)

def get_extra_info(self, info: str) -> typing.Any:
Expand Down Expand Up @@ -105,12 +103,7 @@ async def connect_tcp(
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = []
exc_map = {
TimeoutError: ConnectTimeout,
OSError: ConnectError,
anyio.BrokenResourceError: ConnectError,
}
with map_exceptions(exc_map):
try:
with anyio.fail_after(timeout):
stream: anyio.abc.ByteStream = await anyio.connect_tcp(
remote_host=host,
Expand All @@ -120,6 +113,10 @@ async def connect_tcp(
# By default TCP sockets opened in `asyncio` include TCP_NODELAY.
for option in socket_options:
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
except TimeoutError as exc:
raise ConnectTimeout(exc) from exc
except (OSError, anyio.BrokenResourceError) as exc:
raise ConnectError(exc) from exc
return AnyIOStream(stream)

async def connect_unix_socket(
Expand All @@ -130,16 +127,15 @@ async def connect_unix_socket(
) -> AsyncNetworkStream: # pragma: nocover
if socket_options is None:
socket_options = []
exc_map = {
TimeoutError: ConnectTimeout,
OSError: ConnectError,
anyio.BrokenResourceError: ConnectError,
}
with map_exceptions(exc_map):
try:
with anyio.fail_after(timeout):
stream: anyio.abc.ByteStream = await anyio.connect_unix(path)
for option in socket_options:
stream._raw_socket.setsockopt(*option) # type: ignore[attr-defined] # pragma: no cover
except TimeoutError as exc:
raise ConnectTimeout(exc) from exc
except (OSError, anyio.BrokenResourceError) as exc:
raise ConnectError(exc) from exc
return AnyIOStream(stream)

async def sleep(self, seconds: float) -> None:
Expand Down
60 changes: 35 additions & 25 deletions httpcore/_backends/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@
from .._exceptions import (
ConnectError,
ConnectTimeout,
ExceptionMapping,
ReadError,
ReadTimeout,
WriteError,
WriteTimeout,
map_exceptions,
)
from .._utils import is_socket_readable
from .base import SOCKET_OPTION, NetworkBackend, NetworkStream
Expand Down Expand Up @@ -77,20 +75,26 @@ def _perform_io(
return ret

def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
with map_exceptions(exc_map):
try:
self._sock.settimeout(timeout)
return typing.cast(
bytes, self._perform_io(functools.partial(self.ssl_obj.read, max_bytes))
)
except socket.timeout as exc:
raise ReadTimeout(exc) from exc
except OSError as exc:
raise ReadError(exc) from exc

def write(self, buffer: bytes, timeout: float | None = None) -> None:
exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
with map_exceptions(exc_map):
try:
self._sock.settimeout(timeout)
while buffer:
nsent = self._perform_io(functools.partial(self.ssl_obj.write, buffer))
buffer = buffer[nsent:]
except socket.timeout as exc:
raise WriteTimeout(exc) from exc
except OSError as exc:
raise WriteError(exc) from exc

def close(self) -> None:
self._sock.close()
Expand Down Expand Up @@ -122,21 +126,27 @@ def __init__(self, sock: socket.socket) -> None:
self._sock = sock

def read(self, max_bytes: int, timeout: float | None = None) -> bytes:
exc_map: ExceptionMapping = {socket.timeout: ReadTimeout, OSError: ReadError}
with map_exceptions(exc_map):
try:
self._sock.settimeout(timeout)
return self._sock.recv(max_bytes)
except socket.timeout as exc:
raise ReadTimeout(exc) from exc
except OSError as exc:
raise ReadError(exc) from exc

def write(self, buffer: bytes, timeout: float | None = None) -> None:
if not buffer:
return

exc_map: ExceptionMapping = {socket.timeout: WriteTimeout, OSError: WriteError}
with map_exceptions(exc_map):
try:
while buffer:
self._sock.settimeout(timeout)
n = self._sock.send(buffer)
buffer = buffer[n:]
except socket.timeout as exc:
raise WriteTimeout(exc) from exc
except OSError as exc:
raise WriteError(exc) from exc

def close(self) -> None:
self._sock.close()
Expand All @@ -147,11 +157,7 @@ def start_tls(
server_hostname: str | None = None,
timeout: float | None = None,
) -> NetworkStream:
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
try:
try:
if isinstance(self._sock, ssl.SSLSocket): # pragma: no cover
# If the underlying socket has already been upgraded
Expand All @@ -168,6 +174,10 @@ def start_tls(
except Exception as exc: # pragma: nocover
self.close()
raise exc
except socket.timeout as exc:
raise ConnectTimeout(exc) from exc
except OSError as exc:
raise ConnectError(exc) from exc
return SyncStream(sock)

def get_extra_info(self, info: str) -> typing.Any:
Expand Down Expand Up @@ -199,12 +209,8 @@ def connect_tcp(
socket_options = [] # pragma: no cover
address = (host, port)
source_address = None if local_address is None else (local_address, 0)
exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}

with map_exceptions(exc_map):
try:
sock = socket.create_connection(
address,
timeout,
Expand All @@ -213,6 +219,10 @@ def connect_tcp(
for option in socket_options:
sock.setsockopt(*option) # pragma: no cover
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
except socket.timeout as exc:
raise ConnectTimeout(exc) from exc
except OSError as exc:
raise ConnectError(exc) from exc
return SyncStream(sock)

def connect_unix_socket(
Expand All @@ -228,14 +238,14 @@ def connect_unix_socket(
if socket_options is None:
socket_options = []

exc_map: ExceptionMapping = {
socket.timeout: ConnectTimeout,
OSError: ConnectError,
}
with map_exceptions(exc_map):
try:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
for option in socket_options:
sock.setsockopt(*option)
sock.settimeout(timeout)
sock.connect(path)
except socket.timeout as exc:
raise ConnectTimeout(exc) from exc
except OSError as exc:
raise ConnectError(exc) from exc
return SyncStream(sock)
Loading
Loading