Skip to content

Commit fe8bbaa

Browse files
authored
[Azure] Refresh AAD token on retry (#1003)
* [Azure] Refresh AAD token on retry * add context
1 parent 39731a6 commit fe8bbaa

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

src/index.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,13 @@ export class AzureOpenAI extends OpenAI {
485485
}
486486

487487
protected override async prepareOptions(opts: Core.FinalRequestOptions<unknown>): Promise<void> {
488-
if (opts.headers?.['Authorization'] || opts.headers?.['api-key']) {
488+
/**
489+
* The user should provide a bearer token provider if they want
490+
* to use Azure AD authentication. The user shouldn't set the
491+
* Authorization header manually because the header is overwritten
492+
* with the Azure AD token if a bearer token provider is provided.
493+
*/
494+
if (opts.headers?.['api-key']) {
489495
return super.prepareOptions(opts);
490496
}
491497
const token = await this._getAzureADToken();

tests/lib/azure.test.ts

+37
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,43 @@ describe('instantiate azure client', () => {
254254
/The `apiKey` and `azureADTokenProvider` arguments are mutually exclusive; only one can be passed at a time./,
255255
);
256256
});
257+
258+
test('AAD token is refreshed', async () => {
259+
let fail = true;
260+
const testFetch = async (url: RequestInfo, req: RequestInit | undefined): Promise<Response> => {
261+
if (fail) {
262+
fail = false;
263+
return new Response(undefined, {
264+
status: 429,
265+
headers: {
266+
'Retry-After': '0.1',
267+
},
268+
});
269+
}
270+
return new Response(
271+
JSON.stringify({ auth: (req?.headers as Record<string, string>)['authorization'] }),
272+
{ headers: { 'content-type': 'application/json' } },
273+
);
274+
};
275+
let counter = 0;
276+
async function azureADTokenProvider() {
277+
return `token-${counter++}`;
278+
}
279+
const client = new AzureOpenAI({
280+
baseURL: 'http://localhost:5000/',
281+
azureADTokenProvider,
282+
apiVersion,
283+
fetch: testFetch,
284+
});
285+
expect(
286+
await client.chat.completions.create({
287+
model,
288+
messages: [{ role: 'system', content: 'Hello' }],
289+
}),
290+
).toStrictEqual({
291+
auth: 'Bearer token-1',
292+
});
293+
});
257294
});
258295

259296
test('with endpoint', () => {

0 commit comments

Comments
 (0)