Skip to content

Commit f514473

Browse files
committed
add support for token provider
1 parent 946a3a2 commit f514473

File tree

6 files changed

+185
-84
lines changed

6 files changed

+185
-84
lines changed

src/azure.ts

Lines changed: 2 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import * as Errors from './error';
33
import { FinalRequestOptions } from './internal/request-options';
44
import { isObj, readEnv } from './internal/utils';
55
import { ClientOptions, OpenAI } from './client';
6-
import { buildHeaders, NullableHeaders } from './internal/headers';
76

87
/** API Client for interfacing with the Azure OpenAI API. */
98
export interface AzureClientOptions extends ClientOptions {
@@ -37,7 +36,6 @@ export interface AzureClientOptions extends ClientOptions {
3736

3837
/** API Client for interfacing with the Azure OpenAI API. */
3938
export class AzureOpenAI extends OpenAI {
40-
private _azureADTokenProvider: (() => Promise<string>) | undefined;
4139
deploymentName: string | undefined;
4240
apiVersion: string = '';
4341

@@ -90,9 +88,6 @@ export class AzureOpenAI extends OpenAI {
9088
);
9189
}
9290

93-
// define a sentinel value to avoid any typing issues
94-
apiKey ??= API_KEY_SENTINEL;
95-
9691
opts.defaultQuery = { ...opts.defaultQuery, 'api-version': apiVersion };
9792

9893
if (!baseURL) {
@@ -116,11 +111,12 @@ export class AzureOpenAI extends OpenAI {
116111
super({
117112
apiKey,
118113
baseURL,
114+
tokenProvider:
115+
!azureADTokenProvider ? undefined : async () => ({ token: await azureADTokenProvider() }),
119116
...opts,
120117
...(dangerouslyAllowBrowser !== undefined ? { dangerouslyAllowBrowser } : {}),
121118
});
122119

123-
this._azureADTokenProvider = azureADTokenProvider;
124120
this.apiVersion = apiVersion;
125121
this.deploymentName = deployment;
126122
}
@@ -140,47 +136,6 @@ export class AzureOpenAI extends OpenAI {
140136
}
141137
return super.buildRequest(options, props);
142138
}
143-
144-
async _getAzureADToken(): Promise<string | undefined> {
145-
if (typeof this._azureADTokenProvider === 'function') {
146-
const token = await this._azureADTokenProvider();
147-
if (!token || typeof token !== 'string') {
148-
throw new Errors.OpenAIError(
149-
`Expected 'azureADTokenProvider' argument to return a string but it returned ${token}`,
150-
);
151-
}
152-
return token;
153-
}
154-
return undefined;
155-
}
156-
157-
protected override async authHeaders(opts: FinalRequestOptions): Promise<NullableHeaders | undefined> {
158-
return;
159-
}
160-
161-
protected override async prepareOptions(opts: FinalRequestOptions): Promise<void> {
162-
opts.headers = buildHeaders([opts.headers]);
163-
164-
/**
165-
* The user should provide a bearer token provider if they want
166-
* to use Azure AD authentication. The user shouldn't set the
167-
* Authorization header manually because the header is overwritten
168-
* with the Azure AD token if a bearer token provider is provided.
169-
*/
170-
if (opts.headers.values.get('Authorization') || opts.headers.values.get('api-key')) {
171-
return super.prepareOptions(opts);
172-
}
173-
174-
const token = await this._getAzureADToken();
175-
if (token) {
176-
opts.headers.values.set('Authorization', `Bearer ${token}`);
177-
} else if (this.apiKey !== API_KEY_SENTINEL) {
178-
opts.headers.values.set('api-key', this.apiKey);
179-
} else {
180-
throw new Errors.OpenAIError('Unable to handle auth');
181-
}
182-
return super.prepareOptions(opts);
183-
}
184139
}
185140

186141
const _deployments_endpoints = new Set([
@@ -194,5 +149,3 @@ const _deployments_endpoints = new Set([
194149
'/batches',
195150
'/images/edits',
196151
]);
197-
198-
const API_KEY_SENTINEL = '<Missing Key>';

src/beta/realtime/websocket.ts

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,24 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
9494
}
9595
}
9696

97+
static async create(
98+
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_setToken'>,
99+
props: { model: string; dangerouslyAllowBrowser?: boolean },
100+
): Promise<OpenAIRealtimeWebSocket> {
101+
await client._setToken();
102+
return new OpenAIRealtimeWebSocket(props, client);
103+
}
104+
97105
static async azure(
98-
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
106+
client: Pick<AzureOpenAI, '_setToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
99107
options: { deploymentName?: string; dangerouslyAllowBrowser?: boolean } = {},
100108
): Promise<OpenAIRealtimeWebSocket> {
101-
const token = await client._getAzureADToken();
109+
const isToken = await client._setToken();
102110
function onURL(url: URL) {
103-
if (client.apiKey !== '<Missing Key>') {
104-
url.searchParams.set('api-key', client.apiKey);
111+
if (isToken) {
112+
url.searchParams.set('Authorization', `Bearer ${client.apiKey}`);
105113
} else {
106-
if (token) {
107-
url.searchParams.set('Authorization', `Bearer ${token}`);
108-
} else {
109-
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
110-
}
114+
url.searchParams.set('api-key', client.apiKey);
111115
}
112116
}
113117
const deploymentName = options.deploymentName ?? client.deploymentName;

src/beta/realtime/ws.ts

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,16 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
5151
});
5252
}
5353

54+
static async create(
55+
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_setToken'>,
56+
props: { model: string; options?: WS.ClientOptions | undefined },
57+
): Promise<OpenAIRealtimeWS> {
58+
await client._setToken();
59+
return new OpenAIRealtimeWS(props, client);
60+
}
61+
5462
static async azure(
55-
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
63+
client: Pick<AzureOpenAI, '_setToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
5664
options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {},
5765
): Promise<OpenAIRealtimeWS> {
5866
const deploymentName = options.deploymentName ?? client.deploymentName;
@@ -82,15 +90,11 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
8290
}
8391
}
8492

85-
async function getAzureHeaders(client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiKey'>) {
86-
if (client.apiKey !== '<Missing Key>') {
87-
return { 'api-key': client.apiKey };
93+
async function getAzureHeaders(client: Pick<AzureOpenAI, '_setToken' | 'apiKey'>) {
94+
const isToken = await client._setToken();
95+
if (isToken) {
96+
return { Authorization: `Bearer ${isToken}` };
8897
} else {
89-
const token = await client._getAzureADToken();
90-
if (token) {
91-
return { Authorization: `Bearer ${token}` };
92-
} else {
93-
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
94-
}
98+
return { 'api-key': client.apiKey };
9599
}
96100
}

src/client.ts

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,20 @@ import {
191191
} from './internal/utils/log';
192192
import { isEmptyObj } from './internal/utils/values';
193193

194+
export interface AccessToken {
195+
token: string;
196+
}
197+
export type TokenProvider = () => Promise<AccessToken>;
198+
194199
export interface ClientOptions {
195200
/**
196201
* Defaults to process.env['OPENAI_API_KEY'].
197202
*/
198203
apiKey?: string | undefined;
199-
204+
/**
205+
* A function that returns a token to use for authentication.
206+
*/
207+
tokenProvider?: TokenProvider | undefined;
200208
/**
201209
* Defaults to process.env['OPENAI_ORG_ID'].
202210
*/
@@ -307,6 +315,7 @@ export class OpenAI {
307315
#encoder: Opts.RequestEncoder;
308316
protected idempotencyHeader?: string;
309317
private _options: ClientOptions;
318+
private _tokenProvider: TokenProvider | undefined;
310319

311320
/**
312321
* API Client for interfacing with the OpenAI API.
@@ -330,11 +339,18 @@ export class OpenAI {
330339
organization = readEnv('OPENAI_ORG_ID') ?? null,
331340
project = readEnv('OPENAI_PROJECT_ID') ?? null,
332341
webhookSecret = readEnv('OPENAI_WEBHOOK_SECRET') ?? null,
342+
tokenProvider,
333343
...opts
334344
}: ClientOptions = {}) {
335-
if (apiKey === undefined) {
345+
if (apiKey === undefined && !tokenProvider) {
346+
throw new Errors.OpenAIError(
347+
'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.',
348+
);
349+
}
350+
351+
if (tokenProvider && apiKey) {
336352
throw new Errors.OpenAIError(
337-
"The OPENAI_API_KEY environment variable is missing or empty; either provide it, or instantiate the OpenAI client with an apiKey option, like new OpenAI({ apiKey: 'My API Key' }).",
353+
'The `apiKey` and `tokenProvider` arguments are mutually exclusive; only one can be passed at a time.',
338354
);
339355
}
340356

@@ -343,6 +359,7 @@ export class OpenAI {
343359
organization,
344360
project,
345361
webhookSecret,
362+
tokenProvider,
346363
...opts,
347364
baseURL: baseURL || `https://api.openai.com/v1`,
348365
};
@@ -370,7 +387,8 @@ export class OpenAI {
370387

371388
this._options = options;
372389

373-
this.apiKey = apiKey;
390+
this.apiKey = apiKey ?? 'Missing Key';
391+
this._tokenProvider = tokenProvider;
374392
this.organization = organization;
375393
this.project = project;
376394
this.webhookSecret = webhookSecret;
@@ -390,6 +408,7 @@ export class OpenAI {
390408
fetch: this.fetch,
391409
fetchOptions: this.fetchOptions,
392410
apiKey: this.apiKey,
411+
tokenProvider: this._tokenProvider,
393412
organization: this.organization,
394413
project: this.project,
395414
webhookSecret: this.webhookSecret,
@@ -438,6 +457,31 @@ export class OpenAI {
438457
return Errors.APIError.generate(status, error, message, headers);
439458
}
440459

460+
async _setToken(): Promise<boolean> {
461+
if (typeof this._tokenProvider === 'function') {
462+
try {
463+
const token = await this._tokenProvider();
464+
if (!token || typeof token.token !== 'string') {
465+
throw new Errors.OpenAIError(
466+
`Expected 'tokenProvider' argument to return a string but it returned ${token}`,
467+
);
468+
}
469+
this.apiKey = token.token;
470+
return true;
471+
} catch (err: any) {
472+
if (err instanceof Errors.OpenAIError) {
473+
throw err;
474+
}
475+
throw new Errors.OpenAIError(
476+
`Failed to get token from 'tokenProvider' function: ${err.message}`,
477+
// @ts-ignore
478+
{ cause: err },
479+
);
480+
}
481+
}
482+
return false;
483+
}
484+
441485
buildURL(
442486
path: string,
443487
query: Record<string, unknown> | null | undefined,
@@ -464,7 +508,9 @@ export class OpenAI {
464508
/**
465509
* Used as a callback for mutating the given `FinalRequestOptions` object.
466510
*/
467-
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {}
511+
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {
512+
await this._setToken();
513+
}
468514

469515
/**
470516
* Used as a callback for mutating the given `RequestInit` object.

tests/index.test.ts

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,4 +719,96 @@ describe('retries', () => {
719719
).toEqual(JSON.stringify({ a: 1 }));
720720
expect(count).toEqual(3);
721721
});
722+
723+
describe('auth', () => {
724+
test('apiKey', async () => {
725+
const client = new OpenAI({
726+
baseURL: 'http://localhost:5000/',
727+
apiKey: 'My API Key',
728+
});
729+
const { req } = await client.buildRequest({ path: '/foo', method: 'get' });
730+
expect(req.headers.get('authorization')).toEqual('Bearer My API Key');
731+
});
732+
733+
test('token', async () => {
734+
const testFetch = async (url: any, { headers }: RequestInit = {}): Promise<Response> => {
735+
return new Response(JSON.stringify({}), { headers: headers ?? [] });
736+
};
737+
const client = new OpenAI({
738+
baseURL: 'http://localhost:5000/',
739+
tokenProvider: async () => ({ token: 'my token' }),
740+
fetch: testFetch,
741+
});
742+
expect(
743+
(await client.request({ method: 'post', path: 'https://example.com' }).asResponse()).headers.get(
744+
'authorization',
745+
),
746+
).toEqual('Bearer my token');
747+
});
748+
749+
test('token is refreshed', async () => {
750+
let fail = true;
751+
const testFetch = async (url: any, { headers }: RequestInit = {}): Promise<Response> => {
752+
if (fail) {
753+
fail = false;
754+
return new Response(undefined, {
755+
status: 429,
756+
headers: {
757+
'Retry-After': '0.1',
758+
},
759+
});
760+
}
761+
return new Response(JSON.stringify({}), {
762+
headers: headers ?? [],
763+
});
764+
};
765+
let counter = 0;
766+
async function tokenProvider() {
767+
return { token: `token-${counter++}` };
768+
}
769+
const client = new OpenAI({
770+
baseURL: 'http://localhost:5000/',
771+
tokenProvider,
772+
fetch: testFetch,
773+
});
774+
expect(
775+
(
776+
await client.chat.completions
777+
.create({
778+
model: '',
779+
messages: [{ role: 'system', content: 'Hello' }],
780+
})
781+
.asResponse()
782+
).headers.get('authorization'),
783+
).toEqual('Bearer token-1');
784+
});
785+
786+
test('mutual exclusive', () => {
787+
try {
788+
new OpenAI({
789+
baseURL: 'http://localhost:5000/',
790+
tokenProvider: async () => ({ token: 'my token' }),
791+
apiKey: 'my api key',
792+
});
793+
} catch (error: any) {
794+
expect(error).toBeInstanceOf(Error);
795+
expect(error.message).toEqual(
796+
'The `apiKey` and `tokenProvider` arguments are mutually exclusive; only one can be passed at a time.',
797+
);
798+
}
799+
});
800+
801+
test('at least one', () => {
802+
try {
803+
new OpenAI({
804+
baseURL: 'http://localhost:5000/',
805+
});
806+
} catch (error: any) {
807+
expect(error).toBeInstanceOf(Error);
808+
expect(error.message).toEqual(
809+
'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.',
810+
);
811+
}
812+
});
813+
});
722814
});

0 commit comments

Comments
 (0)