Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from typing_extensions import Unpack, Literal, override, get_origin

import anyio
import socket
import httpx
import distro
import pydantic
Expand Down Expand Up @@ -831,11 +832,72 @@ def _idempotency_key(self) -> str:
return f"stainless-python-retry-{uuid.uuid4()}"


def _build_keepalive_socket_options() -> list[tuple[int, int, int | bool]]:
"""Build socket options for TCP keepalive.

Enables SO_KEEPALIVE and sets platform-appropriate TCP keepalive
parameters to prevent NAT gateways from silently dropping idle
connections during long-running non-streaming requests.
"""
opts: list[tuple[int, int, int | bool]] = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]

if hasattr(socket, "TCP_KEEPIDLE"):
# Linux: seconds before sending the first keepalive probe
opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 60))
elif hasattr(socket, "TCP_KEEPALIVE"):
# macOS: seconds before sending the first keepalive probe
opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPALIVE, 60))

if hasattr(socket, "TCP_KEEPINTVL"):
# Seconds between subsequent keepalive probes
opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 60))

if hasattr(socket, "TCP_KEEPCNT"):
# Number of unacknowledged probes before declaring the connection dead
opts.append((socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 5))

return opts


# ``socket_options`` was added to httpx's transports in httpx 0.25.0. This SDK
# still supports ``httpx>=0.23.0``, so detect support at runtime rather than
# raising the dependency floor (which would be a breaking change for users).
_HTTPX_TRANSPORT_SUPPORTS_SOCKET_OPTIONS = (
"socket_options" in inspect.signature(httpx.HTTPTransport.__init__).parameters
)

# ``httpx.Client``/``httpx.AsyncClient`` normally apply these options when they
# build their own default transport. Supplying a pre-built transport bypasses
# that step, so the effective values must be forwarded to avoid silently
# dropping ``DEFAULT_CONNECTION_LIMITS`` and customizations like ``http2=True``.
_TRANSPORT_PASSTHROUGH_KEYS = ("verify", "cert", "trust_env", "http1", "http2", "limits")


def _build_keepalive_transport(
transport_cls: type[httpx.HTTPTransport] | type[httpx.AsyncHTTPTransport],
kwargs: dict[str, Any],
) -> httpx.HTTPTransport | httpx.AsyncHTTPTransport:
"""Build a default transport with TCP keepalive enabled.

The relevant httpx client options are forwarded so that constructing the
transport explicitly does not discard them, and ``socket_options`` is only
passed when the installed httpx version supports it.
"""
transport_kwargs: dict[str, Any] = {
key: kwargs[key] for key in _TRANSPORT_PASSTHROUGH_KEYS if key in kwargs
}
if _HTTPX_TRANSPORT_SUPPORTS_SOCKET_OPTIONS:
transport_kwargs["socket_options"] = _build_keepalive_socket_options()
return transport_cls(**transport_kwargs)


class _DefaultHttpxClient(httpx.Client):
def __init__(self, **kwargs: Any) -> None:
kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
kwargs.setdefault("follow_redirects", True)
if "transport" not in kwargs:
kwargs["transport"] = _build_keepalive_transport(httpx.HTTPTransport, kwargs)
super().__init__(**kwargs)


Expand Down Expand Up @@ -1423,6 +1485,8 @@ def __init__(self, **kwargs: Any) -> None:
kwargs.setdefault("timeout", DEFAULT_TIMEOUT)
kwargs.setdefault("limits", DEFAULT_CONNECTION_LIMITS)
kwargs.setdefault("follow_redirects", True)
if "transport" not in kwargs:
kwargs["transport"] = _build_keepalive_transport(httpx.AsyncHTTPTransport, kwargs)
super().__init__(**kwargs)


Expand Down
173 changes: 173 additions & 0 deletions tests/test_tcp_keepalive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# File generated to supplement coverage for openai-python PR #3368 (TCP keepalive).
#
# These tests verify that the default sync/async HTTP transports are configured
# with TCP keepalive socket options so that long-lived connections survive NAT
# idle timeouts, while ensuring a caller-supplied client/transport is never
# silently replaced (the ``setdefault`` contract).
from __future__ import annotations

import sys
import socket as socket_module
from typing import Any, List, Tuple

import httpx
import pytest

from openai import OpenAI, AsyncOpenAI
from openai._base_client import (
DefaultHttpxClient,
DefaultAsyncHttpxClient,
_build_keepalive_socket_options,
)

base_url = "http://localhost:4010"
api_key = "My API Key"

SocketOption = Tuple[int, int, int]


def _extract_socket_options(transport: Any) -> List[SocketOption]:
"""Pull the configured ``socket_options`` out of an httpx transport.

httpx stores the options on the underlying httpcore connection pool, which
is exposed (privately) as ``transport._pool._socket_options``. We access it
defensively so a missing attribute fails with a clear assertion rather than
an opaque ``AttributeError``.
"""
pool = getattr(transport, "_pool", None)
assert pool is not None, "transport should expose an underlying connection pool"

socket_options = getattr(pool, "_socket_options", None)
assert socket_options is not None, "connection pool should expose _socket_options"

return list(socket_options)


def _has_option(socket_options: List[SocketOption], level: int, optname: int) -> bool:
"""Return True if an option tuple with the given ``(level, optname)`` exists."""
return any(opt[0] == level and opt[1] == optname for opt in socket_options)


def _idle_optnames() -> List[int]:
"""Collect the platform-specific "idle time before keepalive" optnames.

* Linux exposes ``TCP_KEEPIDLE``.
* macOS exposes ``TCP_KEEPALIVE`` (numerically ``0x10``); some Python builds
do not export the constant, so fall back to the raw value on darwin.
"""
optnames: List[int] = []
if hasattr(socket_module, "TCP_KEEPIDLE"):
optnames.append(socket_module.TCP_KEEPIDLE)
if hasattr(socket_module, "TCP_KEEPALIVE"):
optnames.append(socket_module.TCP_KEEPALIVE)
if not optnames and sys.platform == "darwin":
optnames.append(0x10)
return optnames


class TestOpenAI:
def test_default_sync_transport_has_tcp_keepalive(self) -> None:
client = OpenAI(base_url=base_url, api_key=api_key)
transport = client._client._transport

assert isinstance(
transport, httpx.HTTPTransport
), "Default sync client should use a concrete httpx.HTTPTransport"

socket_options = _extract_socket_options(transport)
assert _has_option(
socket_options, socket_module.SOL_SOCKET, socket_module.SO_KEEPALIVE
), "Default sync transport must enable SO_KEEPALIVE to survive NAT idle timeouts"

def test_keepalive_includes_keepidle_or_keepalive(self) -> None:
idle_optnames = _idle_optnames()
if not idle_optnames:
pytest.skip("platform exposes no TCP keepidle/keepalive option name")

socket_options = _build_keepalive_socket_options()
assert any(
opt[1] in idle_optnames for opt in socket_options
), "Keepalive options must set the idle interval (TCP_KEEPIDLE / TCP_KEEPALIVE)"

def test_keepalive_includes_keepintvl(self) -> None:
if not hasattr(socket_module, "TCP_KEEPINTVL"):
pytest.skip("platform does not support TCP_KEEPINTVL")

socket_options = _build_keepalive_socket_options()
assert _has_option(
socket_options, socket_module.IPPROTO_TCP, socket_module.TCP_KEEPINTVL
), "Keepalive options should set the probe interval (TCP_KEEPINTVL) when supported"

def test_keepalive_includes_keepcnt(self) -> None:
if not hasattr(socket_module, "TCP_KEEPCNT"):
pytest.skip("platform does not support TCP_KEEPCNT")

socket_options = _build_keepalive_socket_options()
assert _has_option(
socket_options, socket_module.IPPROTO_TCP, socket_module.TCP_KEEPCNT
), "Keepalive options should set the probe count (TCP_KEEPCNT) when supported"

def test_custom_http_client_transport_not_overridden(self) -> None:
with httpx.Client() as http_client:
client = OpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
assert (
client._client is http_client
), "A caller-supplied http_client must not be replaced"

def test_custom_transport_kwarg_not_overridden(self) -> None:
# Passing an explicit transport must win over the keepalive default that
# DefaultHttpxClient installs via kwargs.setdefault("transport", ...).
custom_transport = httpx.HTTPTransport()
http_client = DefaultHttpxClient(transport=custom_transport)

assert (
http_client._transport is custom_transport
), "An explicit transport= kwarg must not be overridden by the keepalive default"

def test_build_keepalive_socket_options_returns_valid_list(self) -> None:
socket_options = _build_keepalive_socket_options()

assert isinstance(socket_options, list), "helper should return a list"
assert len(socket_options) >= 1, "helper should return at least SO_KEEPALIVE"

for opt in socket_options:
assert isinstance(opt, tuple), f"each option should be a tuple, got {opt!r}"
assert len(opt) == 3, f"each option should be a (level, optname, value) triple, got {opt!r}"
assert all(
isinstance(part, int) for part in opt
), f"every element of an option triple should be an int, got {opt!r}"

assert any(
opt == (socket_module.SOL_SOCKET, socket_module.SO_KEEPALIVE, 1)
for opt in socket_options
), "helper must always enable SO_KEEPALIVE"


class TestAsyncOpenAI:
async def test_default_async_transport_has_tcp_keepalive(self) -> None:
client = AsyncOpenAI(base_url=base_url, api_key=api_key)
transport = client._client._transport

assert isinstance(
transport, httpx.AsyncHTTPTransport
), "Default async client should use a concrete httpx.AsyncHTTPTransport"

socket_options = _extract_socket_options(transport)
assert _has_option(
socket_options, socket_module.SOL_SOCKET, socket_module.SO_KEEPALIVE
), "Default async transport must enable SO_KEEPALIVE to survive NAT idle timeouts"

async def test_custom_async_http_client_transport_not_overridden(self) -> None:
async with httpx.AsyncClient() as http_client:
client = AsyncOpenAI(base_url=base_url, api_key=api_key, http_client=http_client)
assert (
client._client is http_client
), "A caller-supplied async http_client must not be replaced"

async def test_custom_async_transport_kwarg_not_overridden(self) -> None:
custom_transport = httpx.AsyncHTTPTransport()
http_client = DefaultAsyncHttpxClient(transport=custom_transport)

assert (
http_client._transport is custom_transport
), "An explicit transport= kwarg must not be overridden by the keepalive default"