diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py
index 216b36aabd..ddeb2c61ba 100644
--- a/src/openai/_base_client.py
+++ b/src/openai/_base_client.py
@@ -33,6 +33,7 @@
 from typing_extensions import Unpack, Literal, override, get_origin
 
 import anyio
+import socket
 import httpx
 import distro
 import pydantic
@@ -831,11 +832,96 @@ def _idempotency_key(self) -> str:
         return f"stainless-python-retry-{uuid.uuid4()}"
 
 
+def _build_keepalive_socket_options() -> list[tuple[int, int, int | bool]]:
+    """Build socket options for TCP keepalive.
+
+    Enables SO_KEEPALIVE and sets platform-appropriate TCP keepalive
+    parameters to prevent NAT gateways from silently dropping idle
+    connections during long-running non-streaming requests.
+    """
+    opts: list[tuple[int, int, int | bool]] = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
+
+    if hasattr(socket, "TCP_KEEPIDLE"):
+        # Linux: seconds before sending the first keepalive probe
+        opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60))
+    elif hasattr(socket, "TCP_KEEPALIVE"):
+        # macOS: seconds before sending the first keepalive probe
+        opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, 60))
+
+    if hasattr(socket, "TCP_KEEPINTVL"):
+        # Seconds between subsequent keepalive probes
+        opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 60))
+
+    if hasattr(socket, "TCP_KEEPCNT"):
+        # Number of unacknowledged probes before declaring the connection dead
+        opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5))
+
+    return opts
+
+
+# ``socket_options`` was added to httpx's transports in httpx 0.25.0. This SDK
+# still supports ``httpx>=0.23.0``, so detect support at runtime rather than
+# raising the dependency floor (which would be a breaking change for users).
+_HTTPX_TRANSPORT_SUPPORTS_SOCKET_OPTIONS = (
+    "socket_options" in inspect.signature(httpx.HTTPTransport.__init__).parameters
+)
+
+# ``httpx.Client``/``httpx.AsyncClient`` normally apply these options when they
+# build their own default transport. Supplying a pre-built transport bypasses
+# that step, so the effective values must be forwarded to avoid silently
+# dropping ``DEFAULT_CONNECTION_LIMITS`` and customizations like ``http2=True``.
+_TRANSPORT_PASSTHROUGH_KEYS = ("verify", "cert", "trust_env", "http1", "http2", "limits")
+
+
+def _build_keepalive_transport(
+    transport_cls: type[httpx.HTTPTransport] | type[httpx.AsyncHTTPTransport],
+    kwargs: dict[str, Any],
+) -> httpx.HTTPTransport | httpx.AsyncHTTPTransport:
+    """Build a default transport with TCP keepalive enabled.
+
+    The relevant httpx client options are forwarded so that constructing the
+    transport explicitly does not discard them, and ``socket_options`` is only
+    passed when the installed httpx version supports it.
+    """
+    transport_kwargs: dict[str, Any] = {
+        key: kwargs[key] for key in _TRANSPORT_PASSTHROUGH_KEYS if key in kwargs
+    }
+    if _HTTPX_TRANSPORT_SUPPORTS_SOCKET_OPTIONS:
+        transport_kwargs["socket_options"] = _build_keepalive_socket_options()
+    return transport_cls(**transport_kwargs)
+
+
+def _should_build_keepalive_transport(kwargs: dict[str, Any]) -> bool:
+    """Return whether the keepalive transport may be installed as the default.
+
+    Supplying ``transport=`` changes more than socket configuration, so the
+    keepalive default is skipped whenever it would alter request routing:
+
+    * an explicit ``transport`` entry (even ``None``) is left untouched;
+    * an ``app`` would otherwise be silently replaced by a network transport;
+    * httpx only reads proxy settings from the environment when no transport
+      is supplied, so installing one would bypass ``HTTP_PROXY``,
+      ``HTTPS_PROXY`` and ``ALL_PROXY``.
+    """
+    if "transport" in kwargs or kwargs.get("app") is not None:
+        return False
+    if kwargs.get("trust_env", True):
+        # Imported locally so this helper carries its own stdlib dependency.
+        from urllib.request import getproxies
+
+        detected = getproxies()
+        if any(detected.get(scheme) for scheme in ("http", "https", "all")):
+            return False
+    return True
+
+
 class _DefaultHttpxClient(httpx.Client):
     def __init__(self, **kwargs: Any) -> None:
         kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
         kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
         kwargs.setdefault("follow_redirects", True)
+        if _should_build_keepalive_transport(kwargs):
+            kwargs["transport"] = _build_keepalive_transport(httpx.HTTPTransport, kwargs)
         super().__init__(**kwargs)
 
 
@@ -1423,6 +1509,8 @@ def __init__(self, **kwargs: Any) -> None:
         kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
         kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
         kwargs.setdefault("follow_redirects", True)
+        if _should_build_keepalive_transport(kwargs):
+            kwargs["transport"] = _build_keepalive_transport(httpx.AsyncHTTPTransport, kwargs)
         super().__init__(**kwargs)
 
 
diff --git a/tests/test_tcp_keepalive.py b/tests/test_tcp_keepalive.py
new file mode 100644
index 0000000000..416ce9fd4d
--- /dev/null
+++ b/tests/test_tcp_keepalive.py
@@ -0,0 +1,173 @@
+# File generated to supplement coverage for 
openai-python PR #3368 (TCP keepalive).
+#
+# These tests verify that the default sync/async HTTP transports are configured
+# with TCP keepalive socket options so that long-lived connections survive NAT
+# idle timeouts, while ensuring a caller-supplied client/transport is never
+# silently replaced (an explicit ``transport`` always wins over the default).
+from __future__ import annotations
+
+import sys
+import socket as socket_module
+from typing import Any, List, Tuple
+
+import httpx
+import pytest
+
+from openai import OpenAI, AsyncOpenAI
+from openai._base_client import (
+    DefaultHttpxClient,
+    DefaultAsyncHttpxClient,
+    _build_keepalive_socket_options,
+)
+
+base_url = "http://localhost:4010"
+api_key = "My API Key"
+
+SocketOption = Tuple[int, int, int]
+
+
+def _extract_socket_options(transport: Any) -> List[SocketOption]:
+    """Pull the configured ``socket_options`` out of an httpx transport.
+
+    httpx stores the options on the underlying httpcore connection pool, which
+    is exposed (privately) as ``transport._pool._socket_options``. We access it
+    defensively so a missing attribute fails with a clear assertion rather than
+    an opaque ``AttributeError``.
+    """
+    pool = getattr(transport, "_pool", None)
+    assert pool is not None, "transport should expose an underlying connection pool"
+
+    socket_options = getattr(pool, "_socket_options", None)
+    assert socket_options is not None, "connection pool should expose _socket_options"
+
+    return list(socket_options)
+
+
+def _has_option(socket_options: List[SocketOption], level: int, optname: int) -> bool:
+    """Return True if an option tuple with the given ``(level, optname)`` exists."""
+    return any(opt[0] == level and opt[1] == optname for opt in socket_options)
+
+
+def _idle_optnames() -> List[int]:
+    """Collect the platform-specific "idle time before keepalive" optnames.
+
+    * Linux exposes ``TCP_KEEPIDLE``.
+    * macOS exposes ``TCP_KEEPALIVE`` (numerically ``0x10``); some Python builds
+      do not export the constant, so fall back to the raw value on darwin.
+    """
+    optnames: List[int] = []
+    if hasattr(socket_module, "TCP_KEEPIDLE"):
+        optnames.append(socket_module.TCP_KEEPIDLE)
+    if hasattr(socket_module, "TCP_KEEPALIVE"):
+        optnames.append(socket_module.TCP_KEEPALIVE)
+    if not optnames and sys.platform == "darwin":
+        optnames.append(0x10)
+    return optnames
+
+
+class TestOpenAI:
+    def test_default_sync_transport_has_tcp_keepalive(self) -> None:
+        client = OpenAI(base_url=base_url, api_key=api_key)
+        transport = client._client._transport
+
+        assert isinstance(
+            transport, httpx.HTTPTransport
+        ), "Default sync client should use a concrete httpx.HTTPTransport"
+
+        socket_options = _extract_socket_options(transport)
+        assert _has_option(
+            socket_options, socket_module.SOL_SOCKET, socket_module.SO_KEEPALIVE
+        ), "Default sync transport must enable SO_KEEPALIVE to survive NAT idle timeouts"
+
+    def test_keepalive_includes_keepidle_or_keepalive(self) -> None:
+        idle_optnames = _idle_optnames()
+        if not idle_optnames:
+            pytest.skip("platform exposes no TCP keepidle/keepalive option name")
+
+        socket_options = _build_keepalive_socket_options()
+        assert any(
+            opt[1] in idle_optnames for opt in socket_options
+        ), "Keepalive options must set the idle interval (TCP_KEEPIDLE / TCP_KEEPALIVE)"
+
+    def test_keepalive_includes_keepintvl(self) -> None:
+        if not hasattr(socket_module, "TCP_KEEPINTVL"):
+            pytest.skip("platform does not support TCP_KEEPINTVL")
+
+        socket_options = _build_keepalive_socket_options()
+        assert _has_option(
+            socket_options, socket_module.IPPROTO_TCP, socket_module.TCP_KEEPINTVL
+        ), "Keepalive options should set the probe interval (TCP_KEEPINTVL) when supported"
+
+    def test_keepalive_includes_keepcnt(self) -> None:
+        if not hasattr(socket_module, "TCP_KEEPCNT"):
+            pytest.skip("platform does not support TCP_KEEPCNT")
+
+        socket_options = _build_keepalive_socket_options()
+        assert _has_option(
+            socket_options, socket_module.IPPROTO_TCP, socket_module.TCP_KEEPCNT
+        ), "Keepalive options should set the probe count (TCP_KEEPCNT) when supported"
+
+    def test_custom_http_client_transport_not_overridden(self) -> None:
+        with httpx.Client() as http_client:
+            client = OpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
+            assert (
+                client._client is http_client
+            ), "A caller-supplied http_client must not be replaced"
+
+    def test_custom_transport_kwarg_not_overridden(self) -> None:
+        # Passing an explicit transport must win over the keepalive default that
+        # DefaultHttpxClient only installs when no ``transport`` kwarg is supplied.
+        custom_transport = httpx.HTTPTransport()
+        http_client = DefaultHttpxClient(transport=custom_transport)
+
+        assert (
+            http_client._transport is custom_transport
+        ), "An explicit transport= kwarg must not be overridden by the keepalive default"
+
+    def test_build_keepalive_socket_options_returns_valid_list(self) -> None:
+        socket_options = _build_keepalive_socket_options()
+
+        assert isinstance(socket_options, list), "helper should return a list"
+        assert len(socket_options) >= 1, "helper should return at least SO_KEEPALIVE"
+
+        for opt in socket_options:
+            assert isinstance(opt, tuple), f"each option should be a tuple, got {opt!r}"
+            assert len(opt) == 3, f"each option should be a (level, optname, value) triple, got {opt!r}"
+            assert all(
+                isinstance(part, int) for part in opt
+            ), f"every element of an option triple should be an int, got {opt!r}"
+
+        assert any(
+            opt == (socket_module.SOL_SOCKET, socket_module.SO_KEEPALIVE, 1)
+            for opt in socket_options
+        ), "helper must always enable SO_KEEPALIVE"
+
+
+class TestAsyncOpenAI:
+    async def test_default_async_transport_has_tcp_keepalive(self) -> None:
+        client = AsyncOpenAI(base_url=base_url, api_key=api_key)
+        transport = client._client._transport
+
+        assert isinstance(
+            transport, httpx.AsyncHTTPTransport
+        ), "Default async client should use a concrete httpx.AsyncHTTPTransport"
+
+        socket_options = _extract_socket_options(transport)
+        assert _has_option(
+            socket_options, socket_module.SOL_SOCKET, socket_module.SO_KEEPALIVE
+        ), "Default async transport must enable SO_KEEPALIVE to survive NAT idle timeouts"
+
+    async def test_custom_async_http_client_transport_not_overridden(self) -> None:
+        async with httpx.AsyncClient() as http_client:
+            client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
+            assert (
+                client._client is http_client
+            ), "A caller-supplied async http_client must not be replaced"
+
+    async def test_custom_async_transport_kwarg_not_overridden(self) -> None:
+        custom_transport = httpx.AsyncHTTPTransport()
+        http_client = DefaultAsyncHttpxClient(transport=custom_transport)
+
+        assert (
+            http_client._transport is custom_transport
+        ), "An explicit transport= kwarg must not be overridden by the keepalive default"