diff --git a/openhands-sdk/openhands/sdk/llm/__init__.py b/openhands-sdk/openhands/sdk/llm/__init__.py
index fabed357d1..a5284acf39 100644
--- a/openhands-sdk/openhands/sdk/llm/__init__.py
+++ b/openhands-sdk/openhands/sdk/llm/__init__.py
@@ -1,6 +1,7 @@
 from openhands.sdk.llm.llm import LLM
 from openhands.sdk.llm.llm_registry import LLMRegistry, RegistryEvent
 from openhands.sdk.llm.llm_response import LLMResponse
+from openhands.sdk.llm.llm_with_gateway import LLMWithGateway
 from openhands.sdk.llm.message import (
     ImageContent,
     Message,
@@ -23,6 +24,7 @@
 __all__ = [
     "LLMResponse",
     "LLM",
+    "LLMWithGateway",
     "LLMRegistry",
     "RouterLLM",
     "RegistryEvent",
diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py
index 1972a2005f..ded2e62639 100644
--- a/openhands-sdk/openhands/sdk/llm/llm.py
+++ b/openhands-sdk/openhands/sdk/llm/llm.py
@@ -162,6 +162,14 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin):
     )
     ollama_base_url: str | None = Field(default=None)

+    ssl_verify: bool | str | None = Field(
+        default=None,
+        description=(
+            "TLS verification forwarded to LiteLLM; "
+            "set to False when corporate proxies break certificate chains."
+        ),
+    )
+
     drop_params: bool = Field(default=True)
     modify_params: bool = Field(
         default=True,
@@ -446,15 +454,19 @@ def completion(
         has_tools_flag = bool(cc_tools) and use_native_fc
         # Behavior-preserving: delegate to select_chat_options
         call_kwargs = select_chat_options(self, kwargs, has_tools=has_tools_flag)
+        call_kwargs = self._prepare_request_kwargs(call_kwargs)

         # 4) optional request logging context (kept small)
         assert self._telemetry is not None
         log_ctx = None
         if self._telemetry.log_enabled:
+            sanitized_kwargs = {
+                k: v for k, v in call_kwargs.items() if k != "extra_headers"
+            }
             log_ctx = {
                 "messages": formatted_messages[:],  # already simple dicts
                 "tools": tools,
-                "kwargs": {k: v for k, v in call_kwargs.items()},
+                "kwargs": sanitized_kwargs,
                 "context_window": self.max_input_tokens or 0,
             }
         if tools and not use_native_fc:
@@ -473,7 +485,7 @@ def _one_attempt(**retry_kwargs) -> ModelResponse:
             assert self._telemetry is not None
             # Merge retry-modified kwargs (like temperature) with call_kwargs
-            final_kwargs = {**call_kwargs, **retry_kwargs}
+            final_kwargs = self._prepare_request_kwargs({**call_kwargs, **retry_kwargs})
             resp = self._transport_call(messages=formatted_messages, **final_kwargs)
             raw_resp: ModelResponse | None = None
             if use_mock_tools:
@@ -557,16 +569,20 @@ def responses(
         call_kwargs = select_responses_options(
             self, kwargs, include=include, store=store
         )
+        call_kwargs = self._prepare_request_kwargs(call_kwargs)

         # Optional request logging
         assert self._telemetry is not None
         log_ctx = None
         if self._telemetry.log_enabled:
+            sanitized_kwargs = {
+                k: v for k, v in call_kwargs.items() if k != "extra_headers"
+            }
             log_ctx = {
                 "llm_path": "responses",
                 "input": input_items[:],
                 "tools": tools,
-                "kwargs": {k: v for k, v in call_kwargs.items()},
+                "kwargs": sanitized_kwargs,
                 "context_window": self.max_input_tokens or 0,
             }
         self._telemetry.on_request(log_ctx=log_ctx)
@@ -581,7 +597,7 @@ def responses(
             retry_listener=self.retry_listener,
         )
         def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
-            final_kwargs = {**call_kwargs, **retry_kwargs}
+            final_kwargs = self._prepare_request_kwargs({**call_kwargs, **retry_kwargs})
             with self._litellm_modify_params_ctx(self.modify_params):
                 with warnings.catch_warnings():
                     warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -598,7 +614,9 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
                        else None,
                        api_base=self.base_url,
                        api_version=self.api_version,
+                        custom_llm_provider=self.custom_llm_provider,
                        timeout=self.timeout,
+                        ssl_verify=self.ssl_verify,
                        drop_params=self.drop_params,
                        seed=self.seed,
                        **final_kwargs,
@@ -606,7 +624,7 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
             assert isinstance(ret, ResponsesAPIResponse), (
                 f"Expected ResponsesAPIResponse, got {type(ret)}"
             )
-            # telemetry (latency, cost). Token usage mapping we handle after.
+            # telemetry (latency, cost). Token usage handled after.
             assert self._telemetry is not None
             self._telemetry.on_response(ret)
             return ret
@@ -637,6 +655,11 @@ def _one_attempt(**retry_kwargs) -> ResponsesAPIResponse:
     # =========================================================================
     # Transport + helpers
     # =========================================================================
+    def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]:
+        """Hook for subclasses to adjust final LiteLLM kwargs."""
+
+        return call_kwargs
+
     def _transport_call(
         self, *, messages: list[dict[str, Any]], **kwargs
     ) -> ModelResponse:
@@ -666,7 +689,9 @@ def _transport_call(
             api_key=self.api_key.get_secret_value() if self.api_key else None,
             base_url=self.base_url,
             api_version=self.api_version,
+            custom_llm_provider=self.custom_llm_provider,
             timeout=self.timeout,
+            ssl_verify=self.ssl_verify,
             drop_params=self.drop_params,
             seed=self.seed,
             messages=messages,
@@ -928,6 +953,7 @@ def load_from_json(cls, json_path: str) -> LLM:
     @classmethod
     def load_from_env(cls, prefix: str = "LLM_") -> LLM:
         TRUTHY = {"true", "1", "yes", "on"}
+        FALSY = {"false", "0", "no", "off"}

         def _unwrap_type(t: Any) -> Any:
             origin = get_origin(t)
@@ -936,20 +962,33 @@ def _unwrap_type(t: Any) -> Any:
                 args = [a for a in get_args(t) if a is not type(None)]
                 return args[0] if args else t

-        def _cast_value(raw: str, t: Any) -> Any:
-            t = _unwrap_type(t)
+        def _cast_value(field_name: str, raw: str, annotation: Any) -> Any:
+            stripped = raw.strip()
+            lowered = stripped.lower()
+            if field_name == "ssl_verify":
+                if lowered in TRUTHY:
+                    return True
+                if lowered in FALSY:
+                    return False
+                return stripped
+
+            t = _unwrap_type(annotation)
             if t is SecretStr:
-                return SecretStr(raw)
+                return SecretStr(stripped)
             if t is bool:
-                return raw.lower() in TRUTHY
+                if lowered in TRUTHY:
+                    return True
+                if lowered in FALSY:
+                    return False
+                return None
             if t is int:
                 try:
-                    return int(raw)
+                    return int(stripped)
                 except ValueError:
                     return None
             if t is float:
                 try:
-                    return float(raw)
+                    return float(stripped)
                 except ValueError:
                     return None
             origin = get_origin(t)
@@ -957,10 +996,10 @@ def _cast_value(raw: str, t: Any) -> Any:
                 isinstance(t, type) and issubclass(t, BaseModel)
             ):
                 try:
-                    return json.loads(raw)
+                    return json.loads(stripped)
                 except Exception:
                     pass
-            return raw
+            return stripped

         data: dict[str, Any] = {}
         fields: dict[str, Any] = {
@@ -975,7 +1014,7 @@ def _cast_value(raw: str, t: Any) -> Any:
             field_name = key[len(prefix) :].lower()
             if field_name not in fields:
                 continue
-            v = _cast_value(value, fields[field_name])
+            v = _cast_value(field_name, value, fields[field_name])
             if v is not None:
                 data[field_name] = v
         return cls(**data)
diff --git a/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py
new file mode 100644
index 0000000000..bdbde418a3
--- /dev/null
+++ b/openhands-sdk/openhands/sdk/llm/llm_with_gateway.py
@@ -0,0 +1,91 @@
+"""LLM subclass with enterprise gateway support.
+
+This module provides LLMWithGateway, which extends the base LLM class to support
+custom headers for enterprise API gateways.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Mapping
+from typing import Any
+
+from pydantic import Field
+
+from openhands.sdk.llm.llm import LLM
+from openhands.sdk.logger import get_logger
+
+
+__all__ = ["LLMWithGateway"]
+
+
+logger = get_logger(__name__)
+
+
+class LLMWithGateway(LLM):
+    """LLM subclass with enterprise gateway support.
+
+    Supports adding static custom headers on each request. Take care not to include
+    raw secrets in headers unless the gateway is trusted and headers are never logged.
+    """
+
+    custom_headers: dict[str, str] | None = Field(
+        default=None,
+        description="Custom headers to include with every LLM request.",
+    )
+
+    def _prepare_request_kwargs(self, call_kwargs: dict[str, Any]) -> dict[str, Any]:
+        prepared = dict(super()._prepare_request_kwargs(call_kwargs))
+
+        if not self.custom_headers:
+            return prepared
+
+        existing = prepared.get("extra_headers")
+        base_headers: dict[str, Any]
+        if isinstance(existing, Mapping):
+            base_headers = dict(existing)
+        elif existing is None:
+            base_headers = {}
+        else:
+            base_headers = {}
+
+        merged, collisions = self._merge_headers(base_headers, self.custom_headers)
+        for header, old_val, new_val in collisions:
+            logger.warning(
+                "LLMWithGateway overriding header %s (existing=%r, new=%r)",
+                header,
+                old_val,
+                new_val,
+            )
+
+        if merged:
+            prepared["extra_headers"] = merged
+
+        return prepared
+
+    @staticmethod
+    def _merge_headers(
+        existing: dict[str, Any], new_headers: dict[str, Any]
+    ) -> tuple[dict[str, Any], list[tuple[str, Any, Any]]]:
+        """Merge header dictionaries case-insensitively.
+
+        Returns the merged headers and a list of collisions where an existing
+        header was replaced with a different value.
+        """
+
+        merged = dict(existing)
+        lower_map = {k.lower(): k for k in merged}
+        collisions: list[tuple[str, Any, Any]] = []
+
+        for key, value in new_headers.items():
+            lower = key.lower()
+            if lower in lower_map:
+                canonical = lower_map[lower]
+                old_value = merged[canonical]
+                if old_value != value:
+                    collisions.append((canonical, old_value, value))
+                merged[canonical] = value
+            else:
+                merged[key] = value
+                lower_map[lower] = key
+
+        return merged, collisions
diff --git a/tests/sdk/llm/test_llm_with_gateway.py b/tests/sdk/llm/test_llm_with_gateway.py
new file mode 100644
index 0000000000..94449c5f02
--- /dev/null
+++ b/tests/sdk/llm/test_llm_with_gateway.py
@@ -0,0 +1,110 @@
+"""Tests for LLMWithGateway custom header support."""
+
+from __future__ import annotations
+
+from unittest.mock import patch
+
+from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse
+from openai.types.responses.response_output_message import ResponseOutputMessage
+from openai.types.responses.response_output_text import ResponseOutputText
+from pydantic import SecretStr
+
+from openhands.sdk.llm import LLMWithGateway, Message, TextContent
+from tests.conftest import create_mock_litellm_response
+
+
+def create_llm(custom_headers: dict[str, str] | None = None) -> LLMWithGateway:
+    """Helper to build an LLMWithGateway for tests."""
+    return LLMWithGateway(
+        model="gemini-1.5-flash",
+        api_key=SecretStr("test-api-key"),
+        base_url="https://gateway.example.com/v1",
+        custom_headers=custom_headers,
+        usage_id="gateway-test-llm",
+        num_retries=0,
+    )
+
+
+def make_responses_api_response(text: str) -> ResponsesAPIResponse:
+    """Construct a minimal ResponsesAPIResponse for testing."""
+
+    message = ResponseOutputMessage.model_construct(
+        id="msg",
+        type="message",
+        role="assistant",
+        status="completed",
+        content=[  # type: ignore[arg-type]
+            ResponseOutputText(type="output_text", text=text, annotations=[])
+        ],
+    )
+
+    usage = ResponseAPIUsage(input_tokens=1, output_tokens=1, total_tokens=2)
+
+    return ResponsesAPIResponse(
+        id="resp",
+        created_at=0,
+        output=[message],  # type: ignore[arg-type]
+        parallel_tool_calls=False,
+        tool_choice="auto",
+        top_p=None,
+        tools=[],
+        usage=usage,
+        instructions=None,
+        status="completed",
+    )
+
+
+class TestInitialization:
+    """Basic initialization behaviour."""
+
+    def test_defaults(self) -> None:
+        llm = create_llm()
+        assert llm.custom_headers is None
+
+    def test_custom_headers_configuration(self) -> None:
+        headers = {"X-Custom-Key": "value"}
+        llm = create_llm(custom_headers=headers)
+        assert llm.custom_headers == headers
+
+
+class TestHeaderInjection:
+    """Ensure custom headers are merged into completion calls."""
+
+    @patch("openhands.sdk.llm.llm.litellm_completion")
+    def test_headers_passed_to_litellm(self, mock_completion) -> None:
+        llm = create_llm(custom_headers={"X-Test": "value"})
+        mock_completion.return_value = create_mock_litellm_response(content="Hello!")
+
+        messages = [Message(role="user", content=[TextContent(text="Hi")])]
+        response = llm.completion(messages)
+
+        mock_completion.assert_called_once()
+        headers = mock_completion.call_args.kwargs["extra_headers"]
+        assert headers["X-Test"] == "value"
+
+        # Ensure we still surface the underlying content.
+        content = response.message.content[0]
+        assert isinstance(content, TextContent)
+        assert content.text == "Hello!"
+
+    @patch("openhands.sdk.llm.llm.litellm_completion")
+    def test_headers_merge_existing_extra_headers(self, mock_completion) -> None:
+        llm = create_llm(custom_headers={"X-Test": "value"})
+        mock_completion.return_value = create_mock_litellm_response(content="Merged!")
+
+        messages = [Message(role="user", content=[TextContent(text="Hi")])]
+        llm.completion(messages, extra_headers={"Existing": "1"})
+
+        headers = mock_completion.call_args.kwargs["extra_headers"]
+        assert headers["X-Test"] == "value"
+        assert headers["Existing"] == "1"
+
+    @patch("openhands.sdk.llm.llm.litellm_responses")
+    def test_responses_headers_passed_to_litellm(self, mock_responses) -> None:
+        llm = create_llm(custom_headers={"X-Test": "value"})
+        mock_responses.return_value = make_responses_api_response("ok")
+
+        llm.responses([Message(role="user", content=[TextContent(text="Hi")])])
+
+        headers = mock_responses.call_args.kwargs["extra_headers"]
+        assert headers["X-Test"] == "value"
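
Usage sketch (illustrative, not part of the patch): with this change applied, a client behind an enterprise gateway could be configured as below. The model and endpoint mirror the test fixtures; the header name, API key, and usage_id are placeholders, and ssl_verify=False is only needed when a corporate proxy breaks the certificate chain.

from pydantic import SecretStr

from openhands.sdk.llm import LLMWithGateway, Message, TextContent

# Placeholder endpoint and header values; substitute your deployment's own.
llm = LLMWithGateway(
    model="gemini-1.5-flash",
    api_key=SecretStr("gateway-issued-key"),
    base_url="https://gateway.example.com/v1",
    custom_headers={"X-Org-Id": "acme"},  # merged into extra_headers on every request
    ssl_verify=False,  # only when the proxy intercepts TLS
    usage_id="gateway-example",
)

response = llm.completion([Message(role="user", content=[TextContent(text="Hello")])])
print(response.message)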