From db4bb8e179d6f573b49331d55e97ab4930e19cb9 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Mon, 8 Jul 2024 12:40:52 -0700 Subject: [PATCH] update authorization header with refreshed token --- src/openai/lib/azure.py | 6 ++---- tests/lib/test_azure.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index cbe57b7b98..b13840afe8 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -287,8 +287,7 @@ def _prepare_options(self, options: FinalRequestOptions) -> None: azure_ad_token = self._get_azure_ad_token() if azure_ad_token is not None: - if headers.get("Authorization") is None: - headers["Authorization"] = f"Bearer {azure_ad_token}" + headers["Authorization"] = f"Bearer {azure_ad_token}" elif self.api_key is not API_KEY_SENTINEL: if headers.get("api-key") is None: headers["api-key"] = self.api_key @@ -530,8 +529,7 @@ async def _prepare_options(self, options: FinalRequestOptions) -> None: azure_ad_token = await self._get_azure_ad_token() if azure_ad_token is not None: - if headers.get("Authorization") is None: - headers["Authorization"] = f"Bearer {azure_ad_token}" + headers["Authorization"] = f"Bearer {azure_ad_token}" elif self.api_key is not API_KEY_SENTINEL: if headers.get("api-key") is None: headers["api-key"] = self.api_key diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index 9360b2925a..59cff6c9c5 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -64,3 +64,42 @@ def test_client_copying_override_options(client: Client) -> None: api_version="2022-05-01", ) assert copied._custom_query == {"api-version": "2022-05-01"} + + +def test_client_token_provider_refresh_sync() -> None: + options = FinalRequestOptions.construct( + method="post", + url="/chat/completions", + json_data={"model": "my-deployment-model"}, + headers={"Authorization": "Bearer expired"} + ) + + sync_client = AzureOpenAI( + api_version="2024-02-01", + azure_ad_token_provider=lambda: "valid", + azure_endpoint="https://example-resource.azure.openai.com", + ) + + sync_client._prepare_options(options) + token = options.headers["Authorization"] + assert token == "Bearer valid" + + +@pytest.mark.asyncio +async def test_client_token_provider_refresh_async() -> None: + options = FinalRequestOptions.construct( + method="post", + url="/chat/completions", + json_data={"model": "my-deployment-model"}, + headers={"Authorization": "Bearer expired"} + ) + + async_client = AsyncAzureOpenAI( + api_version="2024-02-01", + azure_ad_token_provider=lambda: "valid", + azure_endpoint="https://example-resource.azure.openai.com", + ) + + await async_client._prepare_options(options) + token = options.headers["Authorization"] + assert token == "Bearer valid"