Skip to content

Commit 276b556

Browse files
committed
add support for token provider
1 parent 862e363 commit 276b556

File tree

6 files changed

+185
-84
lines changed

6 files changed

+185
-84
lines changed

src/azure.ts

Lines changed: 2 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import * as Errors from './error';
33
import { FinalRequestOptions } from './internal/request-options';
44
import { isObj, readEnv } from './internal/utils';
55
import { ClientOptions, OpenAI } from './client';
6-
import { buildHeaders, NullableHeaders } from './internal/headers';
76

87
/** API Client for interfacing with the Azure OpenAI API. */
98
export interface AzureClientOptions extends ClientOptions {
@@ -37,7 +36,6 @@ export interface AzureClientOptions extends ClientOptions {
3736

3837
/** API Client for interfacing with the Azure OpenAI API. */
3938
export class AzureOpenAI extends OpenAI {
40-
private _azureADTokenProvider: (() => Promise<string>) | undefined;
4139
deploymentName: string | undefined;
4240
apiVersion: string = '';
4341

@@ -90,9 +88,6 @@ export class AzureOpenAI extends OpenAI {
9088
);
9189
}
9290

93-
// define a sentinel value to avoid any typing issues
94-
apiKey ??= API_KEY_SENTINEL;
95-
9691
opts.defaultQuery = { ...opts.defaultQuery, 'api-version': apiVersion };
9792

9893
if (!baseURL) {
@@ -116,11 +111,12 @@ export class AzureOpenAI extends OpenAI {
116111
super({
117112
apiKey,
118113
baseURL,
114+
tokenProvider:
115+
!azureADTokenProvider ? undefined : async () => ({ token: await azureADTokenProvider() }),
119116
...opts,
120117
...(dangerouslyAllowBrowser !== undefined ? { dangerouslyAllowBrowser } : {}),
121118
});
122119

123-
this._azureADTokenProvider = azureADTokenProvider;
124120
this.apiVersion = apiVersion;
125121
this.deploymentName = deployment;
126122
}
@@ -140,47 +136,6 @@ export class AzureOpenAI extends OpenAI {
140136
}
141137
return super.buildRequest(options, props);
142138
}
143-
144-
async _getAzureADToken(): Promise<string | undefined> {
145-
if (typeof this._azureADTokenProvider === 'function') {
146-
const token = await this._azureADTokenProvider();
147-
if (!token || typeof token !== 'string') {
148-
throw new Errors.OpenAIError(
149-
`Expected 'azureADTokenProvider' argument to return a string but it returned ${token}`,
150-
);
151-
}
152-
return token;
153-
}
154-
return undefined;
155-
}
156-
157-
protected override authHeaders(opts: FinalRequestOptions): NullableHeaders | undefined {
158-
return;
159-
}
160-
161-
protected override async prepareOptions(opts: FinalRequestOptions): Promise<void> {
162-
opts.headers = buildHeaders([opts.headers]);
163-
164-
/**
165-
* The user should provide a bearer token provider if they want
166-
* to use Azure AD authentication. The user shouldn't set the
167-
* Authorization header manually because the header is overwritten
168-
* with the Azure AD token if a bearer token provider is provided.
169-
*/
170-
if (opts.headers.values.get('Authorization') || opts.headers.values.get('api-key')) {
171-
return super.prepareOptions(opts);
172-
}
173-
174-
const token = await this._getAzureADToken();
175-
if (token) {
176-
opts.headers.values.set('Authorization', `Bearer ${token}`);
177-
} else if (this.apiKey !== API_KEY_SENTINEL) {
178-
opts.headers.values.set('api-key', this.apiKey);
179-
} else {
180-
throw new Errors.OpenAIError('Unable to handle auth');
181-
}
182-
return super.prepareOptions(opts);
183-
}
184139
}
185140

186141
const _deployments_endpoints = new Set([
@@ -194,5 +149,3 @@ const _deployments_endpoints = new Set([
194149
'/batches',
195150
'/images/edits',
196151
]);
197-
198-
const API_KEY_SENTINEL = '<Missing Key>';

src/beta/realtime/websocket.ts

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,20 +94,24 @@ export class OpenAIRealtimeWebSocket extends OpenAIRealtimeEmitter {
9494
}
9595
}
9696

97+
static async create(
98+
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_setToken'>,
99+
props: { model: string; dangerouslyAllowBrowser?: boolean },
100+
): Promise<OpenAIRealtimeWebSocket> {
101+
await client._setToken();
102+
return new OpenAIRealtimeWebSocket(props, client);
103+
}
104+
97105
static async azure(
98-
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
106+
client: Pick<AzureOpenAI, '_setToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
99107
options: { deploymentName?: string; dangerouslyAllowBrowser?: boolean } = {},
100108
): Promise<OpenAIRealtimeWebSocket> {
101-
const token = await client._getAzureADToken();
109+
const isToken = await client._setToken();
102110
function onURL(url: URL) {
103-
if (client.apiKey !== '<Missing Key>') {
104-
url.searchParams.set('api-key', client.apiKey);
111+
if (isToken) {
112+
url.searchParams.set('Authorization', `Bearer ${client.apiKey}`);
105113
} else {
106-
if (token) {
107-
url.searchParams.set('Authorization', `Bearer ${token}`);
108-
} else {
109-
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
110-
}
114+
url.searchParams.set('api-key', client.apiKey);
111115
}
112116
}
113117
const deploymentName = options.deploymentName ?? client.deploymentName;

src/beta/realtime/ws.ts

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,16 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
5151
});
5252
}
5353

54+
static async create(
55+
client: Pick<OpenAI, 'apiKey' | 'baseURL' | '_setToken'>,
56+
props: { model: string; options?: WS.ClientOptions | undefined },
57+
): Promise<OpenAIRealtimeWS> {
58+
await client._setToken();
59+
return new OpenAIRealtimeWS(props, client);
60+
}
61+
5462
static async azure(
55-
client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
63+
client: Pick<AzureOpenAI, '_setToken' | 'apiVersion' | 'apiKey' | 'baseURL' | 'deploymentName'>,
5664
options: { deploymentName?: string; options?: WS.ClientOptions | undefined } = {},
5765
): Promise<OpenAIRealtimeWS> {
5866
const deploymentName = options.deploymentName ?? client.deploymentName;
@@ -82,15 +90,11 @@ export class OpenAIRealtimeWS extends OpenAIRealtimeEmitter {
8290
}
8391
}
8492

85-
async function getAzureHeaders(client: Pick<AzureOpenAI, '_getAzureADToken' | 'apiKey'>) {
86-
if (client.apiKey !== '<Missing Key>') {
87-
return { 'api-key': client.apiKey };
93+
async function getAzureHeaders(client: Pick<AzureOpenAI, '_setToken' | 'apiKey'>) {
94+
const isToken = await client._setToken();
95+
if (isToken) {
96+
return { Authorization: `Bearer ${isToken}` };
8897
} else {
89-
const token = await client._getAzureADToken();
90-
if (token) {
91-
return { Authorization: `Bearer ${token}` };
92-
} else {
93-
throw new Error('AzureOpenAI is not instantiated correctly. No API key or token provided.');
94-
}
98+
return { 'api-key': client.apiKey };
9599
}
96100
}

src/client.ts

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,12 +181,20 @@ import {
181181
} from './internal/utils/log';
182182
import { isEmptyObj } from './internal/utils/values';
183183

184+
export interface AccessToken {
185+
token: string;
186+
}
187+
export type TokenProvider = () => Promise<AccessToken>;
188+
184189
export interface ClientOptions {
185190
/**
186191
* Defaults to process.env['OPENAI_API_KEY'].
187192
*/
188193
apiKey?: string | undefined;
189-
194+
/**
195+
* A function that returns a token to use for authentication.
196+
*/
197+
tokenProvider?: TokenProvider | undefined;
190198
/**
191199
* Defaults to process.env['OPENAI_ORG_ID'].
192200
*/
@@ -297,6 +305,7 @@ export class OpenAI {
297305
#encoder: Opts.RequestEncoder;
298306
protected idempotencyHeader?: string;
299307
private _options: ClientOptions;
308+
private _tokenProvider: TokenProvider | undefined;
300309

301310
/**
302311
* API Client for interfacing with the OpenAI API.
@@ -320,11 +329,18 @@ export class OpenAI {
320329
organization = readEnv('OPENAI_ORG_ID') ?? null,
321330
project = readEnv('OPENAI_PROJECT_ID') ?? null,
322331
webhookSecret = readEnv('OPENAI_WEBHOOK_SECRET') ?? null,
332+
tokenProvider,
323333
...opts
324334
}: ClientOptions = {}) {
325-
if (apiKey === undefined) {
335+
if (apiKey === undefined && !tokenProvider) {
336+
throw new Errors.OpenAIError(
337+
'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.',
338+
);
339+
}
340+
341+
if (tokenProvider && apiKey) {
326342
throw new Errors.OpenAIError(
327-
"The OPENAI_API_KEY environment variable is missing or empty; either provide it, or instantiate the OpenAI client with an apiKey option, like new OpenAI({ apiKey: 'My API Key' }).",
343+
'The `apiKey` and `tokenProvider` arguments are mutually exclusive; only one can be passed at a time.',
328344
);
329345
}
330346

@@ -333,6 +349,7 @@ export class OpenAI {
333349
organization,
334350
project,
335351
webhookSecret,
352+
tokenProvider,
336353
...opts,
337354
baseURL: baseURL || `https://api.openai.com/v1`,
338355
};
@@ -360,7 +377,8 @@ export class OpenAI {
360377

361378
this._options = options;
362379

363-
this.apiKey = apiKey;
380+
this.apiKey = apiKey ?? 'Missing Key';
381+
this._tokenProvider = tokenProvider;
364382
this.organization = organization;
365383
this.project = project;
366384
this.webhookSecret = webhookSecret;
@@ -380,6 +398,7 @@ export class OpenAI {
380398
fetch: this.fetch,
381399
fetchOptions: this.fetchOptions,
382400
apiKey: this.apiKey,
401+
tokenProvider: this._tokenProvider,
383402
organization: this.organization,
384403
project: this.project,
385404
webhookSecret: this.webhookSecret,
@@ -427,6 +446,31 @@ export class OpenAI {
427446
return Errors.APIError.generate(status, error, message, headers);
428447
}
429448

449+
async _setToken(): Promise<boolean> {
450+
if (typeof this._tokenProvider === 'function') {
451+
try {
452+
const token = await this._tokenProvider();
453+
if (!token || typeof token.token !== 'string') {
454+
throw new Errors.OpenAIError(
455+
`Expected 'tokenProvider' argument to return a string but it returned ${token}`,
456+
);
457+
}
458+
this.apiKey = token.token;
459+
return true;
460+
} catch (err: any) {
461+
if (err instanceof Errors.OpenAIError) {
462+
throw err;
463+
}
464+
throw new Errors.OpenAIError(
465+
`Failed to get token from 'tokenProvider' function: ${err.message}`,
466+
// @ts-ignore
467+
{ cause: err },
468+
);
469+
}
470+
}
471+
return false;
472+
}
473+
430474
buildURL(
431475
path: string,
432476
query: Record<string, unknown> | null | undefined,
@@ -453,7 +497,9 @@ export class OpenAI {
453497
/**
454498
* Used as a callback for mutating the given `FinalRequestOptions` object.
455499
*/
456-
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {}
500+
protected async prepareOptions(options: FinalRequestOptions): Promise<void> {
501+
await this._setToken();
502+
}
457503

458504
/**
459505
* Used as a callback for mutating the given `RequestInit` object.

tests/index.test.ts

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -719,4 +719,96 @@ describe('retries', () => {
719719
).toEqual(JSON.stringify({ a: 1 }));
720720
expect(count).toEqual(3);
721721
});
722+
723+
describe('auth', () => {
724+
test('apiKey', () => {
725+
const client = new OpenAI({
726+
baseURL: 'http://localhost:5000/',
727+
apiKey: 'My API Key',
728+
});
729+
const { req } = client.buildRequest({ path: '/foo', method: 'get' });
730+
expect(req.headers.get('authorization')).toEqual('Bearer My API Key');
731+
});
732+
733+
test('token', async () => {
734+
const testFetch = async (url: any, { headers }: RequestInit = {}): Promise<Response> => {
735+
return new Response(JSON.stringify({}), { headers: headers ?? [] });
736+
};
737+
const client = new OpenAI({
738+
baseURL: 'http://localhost:5000/',
739+
tokenProvider: async () => ({ token: 'my token' }),
740+
fetch: testFetch,
741+
});
742+
expect(
743+
(await client.request({ method: 'post', path: 'https://example.com' }).asResponse()).headers.get(
744+
'authorization',
745+
),
746+
).toEqual('Bearer my token');
747+
});
748+
749+
test('token is refreshed', async () => {
750+
let fail = true;
751+
const testFetch = async (url: any, { headers }: RequestInit = {}): Promise<Response> => {
752+
if (fail) {
753+
fail = false;
754+
return new Response(undefined, {
755+
status: 429,
756+
headers: {
757+
'Retry-After': '0.1',
758+
},
759+
});
760+
}
761+
return new Response(JSON.stringify({}), {
762+
headers: headers ?? [],
763+
});
764+
};
765+
let counter = 0;
766+
async function tokenProvider() {
767+
return { token: `token-${counter++}` };
768+
}
769+
const client = new OpenAI({
770+
baseURL: 'http://localhost:5000/',
771+
tokenProvider,
772+
fetch: testFetch,
773+
});
774+
expect(
775+
(
776+
await client.chat.completions
777+
.create({
778+
model: '',
779+
messages: [{ role: 'system', content: 'Hello' }],
780+
})
781+
.asResponse()
782+
).headers.get('authorization'),
783+
).toEqual('Bearer token-1');
784+
});
785+
786+
test('mutual exclusive', () => {
787+
try {
788+
new OpenAI({
789+
baseURL: 'http://localhost:5000/',
790+
tokenProvider: async () => ({ token: 'my token' }),
791+
apiKey: 'my api key',
792+
});
793+
} catch (error: any) {
794+
expect(error).toBeInstanceOf(Error);
795+
expect(error.message).toEqual(
796+
'The `apiKey` and `tokenProvider` arguments are mutually exclusive; only one can be passed at a time.',
797+
);
798+
}
799+
});
800+
801+
test('at least one', () => {
802+
try {
803+
new OpenAI({
804+
baseURL: 'http://localhost:5000/',
805+
});
806+
} catch (error: any) {
807+
expect(error).toBeInstanceOf(Error);
808+
expect(error.message).toEqual(
809+
'Missing credentials. Please pass one of `apiKey` and `tokenProvider`, or set the `OPENAI_API_KEY` environment variable.',
810+
);
811+
}
812+
});
813+
});
722814
});

0 commit comments

Comments
 (0)