Skip to content

Commit a50ff0a

Browse files
committed
Fix Bedrock with_options overrides
1 parent 8be32d3 commit a50ff0a

2 files changed

Lines changed: 86 additions & 8 deletions

File tree

src/openai/lib/bedrock.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,17 @@ def _resolve_bedrock_base_url(base_url: str | httpx.URL | None, aws_region: str
5454
return _normalize_bedrock_base_url(base_url)
5555

5656

57+
def _uses_region_derived_bedrock_base_url(base_url: str | httpx.URL | None) -> bool:
58+
if isinstance(base_url, str) and not base_url.strip():
59+
base_url = None
60+
61+
if base_url is not None:
62+
return False
63+
64+
env_base_url = os.environ.get("AWS_BEDROCK_BASE_URL")
65+
return env_base_url is None or not env_base_url.strip()
66+
67+
5768
def _bedrock_token_provider(provider: BedrockTokenProvider) -> BedrockTokenProvider:
5869
"""Adapt a sync Bedrock token provider to the base client's api_key callback."""
5970

@@ -87,6 +98,7 @@ class BedrockOpenAI(OpenAI):
8798
"""API client for Amazon Bedrock's OpenAI-compatible endpoint."""
8899

89100
_bedrock_token_provider: BedrockTokenProvider | None
101+
_uses_region_derived_base_url: bool
90102
aws_region: str | None
91103

92104
def __init__(
@@ -133,6 +145,7 @@ def __init__(
133145
)
134146

135147
self._bedrock_token_provider = bedrock_token_provider
148+
self._uses_region_derived_base_url = _uses_region_derived_bedrock_base_url(base_url)
136149
self.aws_region = aws_region
137150

138151
super().__init__(
@@ -223,10 +236,17 @@ def copy(
223236
elif set_default_query is not None:
224237
params = set_default_query
225238

226-
next_token_provider = (
227-
bedrock_token_provider if bedrock_token_provider is not None else self._bedrock_token_provider
228-
)
239+
if api_key is not None:
240+
next_token_provider = None
241+
elif bedrock_token_provider is not None:
242+
next_token_provider = bedrock_token_provider
243+
else:
244+
next_token_provider = self._bedrock_token_provider
245+
229246
next_api_key = api_key if api_key is not None else (None if next_token_provider is not None else self.api_key)
247+
next_base_url = base_url
248+
if next_base_url is None and not (aws_region is not None and self._uses_region_derived_base_url):
249+
next_base_url = self.base_url
230250

231251
return self.__class__(
232252
api_key=next_api_key,
@@ -236,7 +256,7 @@ def copy(
236256
project=project if project is not None else self.project,
237257
webhook_secret=webhook_secret if webhook_secret is not None else self.webhook_secret,
238258
websocket_base_url=websocket_base_url if websocket_base_url is not None else self.websocket_base_url,
239-
base_url=base_url if base_url is not None else self.base_url,
259+
base_url=next_base_url,
240260
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
241261
http_client=http_client or self._client,
242262
max_retries=max_retries if is_given(max_retries) else self.max_retries,
@@ -253,6 +273,7 @@ class AsyncBedrockOpenAI(AsyncOpenAI):
253273
"""Async API client for Amazon Bedrock's OpenAI-compatible endpoint."""
254274

255275
_bedrock_token_provider: AsyncBedrockTokenProvider | None
276+
_uses_region_derived_base_url: bool
256277
aws_region: str | None
257278

258279
def __init__(
@@ -299,6 +320,7 @@ def __init__(
299320
)
300321

301322
self._bedrock_token_provider = bedrock_token_provider
323+
self._uses_region_derived_base_url = _uses_region_derived_bedrock_base_url(base_url)
302324
self.aws_region = aws_region
303325

304326
super().__init__(
@@ -391,10 +413,17 @@ def copy(
391413
elif set_default_query is not None:
392414
params = set_default_query
393415

394-
next_token_provider = (
395-
bedrock_token_provider if bedrock_token_provider is not None else self._bedrock_token_provider
396-
)
416+
if api_key is not None:
417+
next_token_provider = None
418+
elif bedrock_token_provider is not None:
419+
next_token_provider = bedrock_token_provider
420+
else:
421+
next_token_provider = self._bedrock_token_provider
422+
397423
next_api_key = api_key if api_key is not None else (None if next_token_provider is not None else self.api_key)
424+
next_base_url = base_url
425+
if next_base_url is None and not (aws_region is not None and self._uses_region_derived_base_url):
426+
next_base_url = self.base_url
398427

399428
return self.__class__(
400429
api_key=next_api_key,
@@ -404,7 +433,7 @@ def copy(
404433
project=project if project is not None else self.project,
405434
webhook_secret=webhook_secret if webhook_secret is not None else self.webhook_secret,
406435
websocket_base_url=websocket_base_url if websocket_base_url is not None else self.websocket_base_url,
407-
base_url=base_url if base_url is not None else self.base_url,
436+
base_url=next_base_url,
408437
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
409438
http_client=http_client or self._client,
410439
max_retries=max_retries if is_given(max_retries) else self.max_retries,

tests/lib/test_bedrock.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,55 @@ def test_preserves_token_provider_across_with_options() -> None:
252252
assert copied_client._refresh_api_key() == "provider token"
253253

254254

255+
@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI])
256+
def test_with_options_api_key_replaces_token_provider(client_cls: type[Client]) -> None:
257+
client = (
258+
make_sync_client(
259+
base_url="https://example.com/openai/v1",
260+
bedrock_token_provider=lambda: "provider token",
261+
)
262+
if client_cls is BedrockOpenAI
263+
else make_async_client(
264+
base_url="https://example.com/openai/v1",
265+
bedrock_token_provider=lambda: "provider token",
266+
)
267+
)
268+
269+
copied_client = client.with_options(api_key="static token")
270+
271+
assert copied_client.api_key == "static token"
272+
assert copied_client._bedrock_token_provider is None
273+
274+
275+
@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI])
276+
def test_with_options_aws_region_recomputes_region_derived_base_url(client_cls: type[Client]) -> None:
277+
with update_env(AWS_BEDROCK_BASE_URL=Omit(), AWS_REGION=Omit(), AWS_DEFAULT_REGION=Omit()):
278+
client = (
279+
make_sync_client(aws_region="us-east-1", api_key="token")
280+
if client_cls is BedrockOpenAI
281+
else make_async_client(aws_region="us-east-1", api_key="token")
282+
)
283+
284+
copied_client = client.with_options(aws_region="eu-west-1")
285+
286+
assert copied_client.aws_region == "eu-west-1"
287+
assert copied_client.base_url == URL("https://bedrock-mantle.eu-west-1.api.aws/openai/v1/")
288+
289+
290+
@pytest.mark.parametrize("client_cls", [BedrockOpenAI, AsyncBedrockOpenAI])
291+
def test_with_options_aws_region_keeps_explicit_base_url(client_cls: type[Client]) -> None:
292+
client = (
293+
make_sync_client(base_url="https://example.com/openai/v1", aws_region="us-east-1", api_key="token")
294+
if client_cls is BedrockOpenAI
295+
else make_async_client(base_url="https://example.com/openai/v1", aws_region="us-east-1", api_key="token")
296+
)
297+
298+
copied_client = client.with_options(aws_region="eu-west-1")
299+
300+
assert copied_client.aws_region == "eu-west-1"
301+
assert copied_client.base_url == URL("https://example.com/openai/v1/")
302+
303+
255304
@pytest.mark.parametrize(
256305
"copy_kwargs",
257306
[

0 commit comments

Comments
 (0)