Skip to content

Bearer token provider addition to openaiclient #2470

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 46 additions & 12 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
from __future__ import annotations

import os
from typing import TYPE_CHECKING, Any, Union, Mapping
from typing import TYPE_CHECKING, Any, Union, Mapping, Callable, Awaitable
from typing_extensions import Self, override

import httpx

from openai._models import FinalRequestOptions

from . import _exceptions
from ._qs import Querystring
from ._types import (
Expand Down Expand Up @@ -91,6 +93,7 @@ def __init__(
self,
*,
api_key: str | None = None,
bearer_token_provider: Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | httpx.URL | None = None,
Expand Down Expand Up @@ -120,13 +123,16 @@ def __init__(
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
"""
if api_key and bearer_token_provider:
raise ValueError("The `api_key` and `bearer_token_provider` arguments are mutually exclusive")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None:
if api_key is None and bearer_token_provider is None:
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
self.bearer_token_provider = bearer_token_provider
self.api_key = api_key or ''

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -155,6 +161,7 @@ def __init__(
)

self._default_stream_cls = Stream
self._auth_headers: dict[str, str] = {}

@cached_property
def completions(self) -> Completions:
Expand Down Expand Up @@ -259,18 +266,26 @@ def with_raw_response(self) -> OpenAIWithRawResponse:
@cached_property
def with_streaming_response(self) -> OpenAIWithStreamedResponse:
return OpenAIWithStreamedResponse(self)

@property
@override
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

def refresh_auth_headers(self):
bearer_token = self.bearer_token_provider() if self.bearer_token_provider else self.api_key
self._auth_headers = {"Authorization": f"Bearer {bearer_token}"}


@override
def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
self.refresh_auth_headers()
return super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"Authorization": f"Bearer {api_key}"}

return self._auth_headers

@property
@override
def default_headers(self) -> dict[str, str | Omit]:
Expand All @@ -286,6 +301,7 @@ def copy(
self,
*,
api_key: str | None = None,
bearer_token_provider: Callable[[], str] | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
Expand Down Expand Up @@ -323,6 +339,7 @@ def copy(
http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
bearer_token_provider = bearer_token_provider or self.bearer_token_provider,
organization=organization or self.organization,
project=project or self.project,
websocket_base_url=websocket_base_url or self.websocket_base_url,
Expand Down Expand Up @@ -392,6 +409,7 @@ def __init__(
self,
*,
api_key: str | None = None,
bearer_token_provider: Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
base_url: str | httpx.URL | None = None,
Expand Down Expand Up @@ -421,13 +439,16 @@ def __init__(
- `organization` from `OPENAI_ORG_ID`
- `project` from `OPENAI_PROJECT_ID`
"""
if api_key and bearer_token_provider:
raise ValueError("The `api_key` and `bearer_token_provider` arguments are mutually exclusive")
if api_key is None:
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None:
if api_key is None and bearer_token_provider is None:
raise OpenAIError(
"The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
)
self.api_key = api_key
self.bearer_token_provider = bearer_token_provider
self.api_key = api_key or ''

if organization is None:
organization = os.environ.get("OPENAI_ORG_ID")
Expand Down Expand Up @@ -456,6 +477,7 @@ def __init__(
)

self._default_stream_cls = AsyncStream
self._auth_headers: dict[str, str] = {}

@cached_property
def completions(self) -> AsyncCompletions:
Expand Down Expand Up @@ -566,12 +588,22 @@ def with_streaming_response(self) -> AsyncOpenAIWithStreamedResponse:
def qs(self) -> Querystring:
return Querystring(array_format="brackets")

async def refresh_auth_headers(self):
if self.bearer_token_provider:
bearer_token = await self.bearer_token_provider()
else:
bearer_token = self.api_key
self._auth_headers = {"Authorization": f"Bearer {bearer_token}"}

@override
async def _prepare_options(self, options: FinalRequestOptions) -> FinalRequestOptions:
await self.refresh_auth_headers()
return await super()._prepare_options(options)

@property
@override
def auth_headers(self) -> dict[str, str]:
api_key = self.api_key
return {"Authorization": f"Bearer {api_key}"}

return self._auth_headers
@property
@override
def default_headers(self) -> dict[str, str | Omit]:
Expand All @@ -587,6 +619,7 @@ def copy(
self,
*,
api_key: str | None = None,
bearer_token_provider: Callable[[], Awaitable[str]] | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
Expand Down Expand Up @@ -624,6 +657,7 @@ def copy(
http_client = http_client or self._client
return self.__class__(
api_key=api_key or self.api_key,
bearer_token_provider = bearer_token_provider or self.bearer_token_provider,
organization=organization or self.organization,
project=project or self.project,
websocket_base_url=websocket_base_url or self.websocket_base_url,
Expand Down
2 changes: 1 addition & 1 deletion src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ async def _configure_realtime(self, model: str, extra_query: Query) -> tuple[htt
"api-version": self._api_version,
"deployment": self._azure_deployment or model,
}
if self.api_key != "<missing API key>":
if self.api_key and self.api_key != "<missing API key>":
auth_headers = {"api-key": self.api_key}
else:
token = await self._get_azure_ad_token()
Expand Down
2 changes: 2 additions & 0 deletions src/openai/resources/beta/realtime/realtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ async def __aenter__(self) -> AsyncRealtimeConnection:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
await self.__client.refresh_auth_headers()
auth_headers = self.__client.auth_headers
if is_async_azure_client(self.__client):
url, auth_headers = await self.__client._configure_realtime(self.__model, extra_query)
Expand Down Expand Up @@ -540,6 +541,7 @@ def __enter__(self) -> RealtimeConnection:
raise OpenAIError("You need to install `openai[realtime]` to use this method") from exc

extra_query = self.__extra_query
self.__client.refresh_auth_headers()
auth_headers = self.__client.auth_headers
if is_azure_client(self.__client):
url, auth_headers = self.__client._configure_realtime(self.__model, extra_query)
Expand Down