Skip to content

Enable custom auth for OpenAI clients (prototype) #2427

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@

from .lib import azure as _azure, pydantic_function_tool as pydantic_function_tool
from .version import VERSION as VERSION
from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI
from .lib.azure import AzureOpenAI as AzureOpenAI, AsyncAzureOpenAI as AsyncAzureOpenAI, AzureAuth as AzureAuth
from .lib._old_api import *
from .lib.streaming import (
AssistantEventHandler as AssistantEventHandler,
Expand Down
14 changes: 14 additions & 0 deletions src/openai/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,7 @@ def __init__(
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
auth: httpx.Auth | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
_strict_response_validation: bool,
Expand Down Expand Up @@ -856,6 +857,12 @@ def __init__(
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
)
self._custom_auth = auth

@property
@override
def custom_auth(self) -> httpx.Auth | None:
return self._custom_auth

def is_closed(self) -> bool:
return self._client.is_closed
Expand Down Expand Up @@ -1343,6 +1350,7 @@ def __init__(
max_retries: int = DEFAULT_MAX_RETRIES,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
auth: httpx.Auth | None = None,
custom_headers: Mapping[str, str] | None = None,
custom_query: Mapping[str, object] | None = None,
) -> None:
Expand Down Expand Up @@ -1379,7 +1387,13 @@ def __init__(
# cast to a valid type because mypy doesn't understand our type narrowing
timeout=cast(Timeout, timeout),
)
self._custom_auth = auth

@property
@override
def custom_auth(self) -> httpx.Auth | None:
return self._custom_auth

def is_closed(self) -> bool:
return self._client.is_closed

Expand Down
8 changes: 8 additions & 0 deletions src/openai/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
# We provide a `DefaultHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
# See the [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
http_client: httpx.Client | None = None,
auth: httpx.Auth | None = None,
# Enable or disable schema validation for data returned by the API.
# When enabled an error APIResponseValidationError is raised
# if the API responds with invalid data for the expected schema.
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(
max_retries=max_retries,
timeout=timeout,
http_client=http_client,
auth = auth,
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
Expand Down Expand Up @@ -292,6 +294,7 @@ def copy(
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.Client | None = None,
auth: httpx.Auth | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
Expand Down Expand Up @@ -329,6 +332,7 @@ def copy(
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
auth=auth or self.custom_auth,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
Expand Down Expand Up @@ -404,6 +408,7 @@ def __init__(
# We provide a `DefaultAsyncHttpxClient` class that you can pass to retain the default values we use for `limits`, `timeout` & `follow_redirects`.
# See the [httpx documentation](https://www.python-httpx.org/api/#asyncclient) for more details.
http_client: httpx.AsyncClient | None = None,
auth: httpx.Auth | None = None,
# Enable or disable schema validation for data returned by the API.
# When enabled an error APIResponseValidationError is raised
# if the API responds with invalid data for the expected schema.
Expand Down Expand Up @@ -450,6 +455,7 @@ def __init__(
max_retries=max_retries,
timeout=timeout,
http_client=http_client,
auth=auth,
custom_headers=default_headers,
custom_query=default_query,
_strict_response_validation=_strict_response_validation,
Expand Down Expand Up @@ -593,6 +599,7 @@ def copy(
base_url: str | httpx.URL | None = None,
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
http_client: httpx.AsyncClient | None = None,
auth: httpx.Auth | None = None,
max_retries: int | NotGiven = NOT_GIVEN,
default_headers: Mapping[str, str] | None = None,
set_default_headers: Mapping[str, str] | None = None,
Expand Down Expand Up @@ -630,6 +637,7 @@ def copy(
base_url=base_url or self.base_url,
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
http_client=http_client,
auth=auth or self.custom_auth,
max_retries=max_retries if is_given(max_retries) else self.max_retries,
default_headers=headers,
default_query=params,
Expand Down
15 changes: 14 additions & 1 deletion src/openai/lib/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
import inspect
from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, cast, overload
from typing import Any, Union, Mapping, TypeVar, Callable, Generator, Awaitable, cast, overload
from typing_extensions import Self, override

import httpx
Expand Down Expand Up @@ -85,6 +85,15 @@ def _prepare_url(self, url: str) -> httpx.URL:

return super()._prepare_url(url)

class AzureAuth(httpx.Auth):
def __init__(self, credential: Any, *, scope: str = 'https://cognitiveservices.azure.com/.default'):
self.credential = credential
self.scope = scope

@override
def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Response, None]:
request.headers['Authorization'] = 'Bearer ' + self.credential.get_token(self.scope).token
yield request

class AzureOpenAI(BaseAzureClient[httpx.Client, Stream[Any]], OpenAI):
@overload
Expand Down Expand Up @@ -254,6 +263,7 @@ def copy(
self,
*,
api_key: str | None = None,
auth: httpx.Auth | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
Expand Down Expand Up @@ -426,6 +436,7 @@ def __init__(
azure_deployment: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
auth: httpx.Auth | None = None,
azure_ad_token: str | None = None,
azure_ad_token_provider: AsyncAzureADTokenProvider | None = None,
organization: str | None = None,
Expand Down Expand Up @@ -528,6 +539,7 @@ def copy(
self,
*,
api_key: str | None = None,
auth: httpx.Auth | None = None,
organization: str | None = None,
project: str | None = None,
websocket_base_url: str | httpx.URL | None = None,
Expand All @@ -549,6 +561,7 @@ def copy(
"""
return super().copy(
api_key=api_key,
auth=auth,
organization=organization,
project=project,
websocket_base_url=websocket_base_url,
Expand Down