|
79 | 79 | "AsyncStream", |
80 | 80 | "OpenAI", |
81 | 81 | "AsyncOpenAI", |
| 82 | + "BedrockOpenAI", |
| 83 | + "AsyncBedrockOpenAI", |
82 | 84 | "file_from_path", |
83 | 85 | "BaseModel", |
84 | 86 | "DEFAULT_TIMEOUT", |
|
96 | 98 | if not _t.TYPE_CHECKING: |
97 | 99 | from ._utils._resources_proxy import resources as resources |
98 | 100 |
|
99 | | -from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool |
| 101 | +from .lib import azure as _azure, bedrock as _bedrock, pydantic_function_tool as pydantic_function_tool |
100 | 102 | from .version import VERSION as VERSION |
101 | 103 | from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI |
| 104 | +from .lib.bedrock import BedrockOpenAI as BedrockOpenAI, AsyncBedrockOpenAI as AsyncBedrockOpenAI |
102 | 105 | from .lib._old_api import * |
103 | 106 | from .lib.streaming import ( |
104 | 107 | AssistantEventHandler as AssistantEventHandler, |
|
150 | 153 |
|
151 | 154 | http_client: _httpx.Client | None = None |
152 | 155 |
|
153 | | -_ApiType = _te.Literal["openai", "azure"] |
| 156 | +_ApiType = _te.Literal["openai", "azure", "amazon-bedrock"] |
154 | 157 |
|
155 | 158 | api_type: _ApiType | None = _t.cast(_ApiType, _os.environ.get("OPENAI_API_TYPE")) |
156 | 159 |
|
|
162 | 165 |
|
163 | 166 | azure_ad_token_provider: _azure.AzureADTokenProvider | None = None |
164 | 167 |
|
| 168 | +_bedrock_api_key: str | None = None |
| 169 | + |
| 170 | +bedrock_token_provider: _bedrock.BedrockTokenProvider | None = None |
| 171 | + |
165 | 172 |
|
166 | 173 | class _ModuleClient(OpenAI): |
167 | 174 | # Note: we have to use type: ignores here as overriding class members |
@@ -294,10 +301,23 @@ class _AzureModuleClient(_ModuleClient, AzureOpenAI): # type: ignore |
294 | 301 | ... |
295 | 302 |
|
296 | 303 |
|
| 304 | +class _BedrockModuleClient(_ModuleClient, BedrockOpenAI): # type: ignore |
| 305 | + @property # type: ignore |
| 306 | + @override |
| 307 | + def api_key(self) -> str | None: |
| 308 | + return api_key if api_key is not None else _bedrock_api_key |
| 309 | + |
| 310 | + @api_key.setter # type: ignore |
| 311 | + def api_key(self, value: str | None) -> None: # type: ignore |
| 312 | + global _bedrock_api_key |
| 313 | + |
| 314 | + _bedrock_api_key = value |
| 315 | + |
| 316 | + |
297 | 317 | class _AmbiguousModuleClientUsageError(OpenAIError): |
298 | 318 | def __init__(self) -> None: |
299 | 319 | super().__init__( |
300 | | - "Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai` or `azure`" |
| 320 | + "Ambiguous use of module client; please set `openai.api_type` or the `OPENAI_API_TYPE` environment variable to `openai`, `azure`, or `amazon-bedrock`" |
301 | 321 | ) |
302 | 322 |
|
303 | 323 |
|
@@ -370,6 +390,22 @@ def _load_client() -> OpenAI: # type: ignore[reportUnusedFunction] |
370 | 390 | ) |
371 | 391 | return _client |
372 | 392 |
|
| 393 | + if api_type == "amazon-bedrock": |
| 394 | + _client = _BedrockModuleClient( # type: ignore |
| 395 | + api_key=api_key, |
| 396 | + bedrock_token_provider=bedrock_token_provider, |
| 397 | + organization=organization, |
| 398 | + project=project, |
| 399 | + webhook_secret=webhook_secret, |
| 400 | + base_url=base_url, |
| 401 | + timeout=timeout, |
| 402 | + max_retries=max_retries, |
| 403 | + default_headers=default_headers, |
| 404 | + default_query=default_query, |
| 405 | + http_client=http_client, |
| 406 | + ) |
| 407 | + return _client |
| 408 | + |
373 | 409 | _client = _ModuleClient( |
374 | 410 | api_key=api_key, |
375 | 411 | admin_api_key=admin_api_key, |
|
0 commit comments