From 065cc601496515f2185eb31b2b028f6821f5b6b3 Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Thu, 4 Jan 2024 11:29:28 -0800 Subject: [PATCH 1/3] honor retry-after-ms for azure clients --- src/openai/lib/azure.py | 32 +++++++++++++++++++++++++++++++- tests/lib/test_azure.py | 39 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/src/openai/lib/azure.py b/src/openai/lib/azure.py index 27bebd8cab..310c337bf2 100644 --- a/src/openai/lib/azure.py +++ b/src/openai/lib/azure.py @@ -2,7 +2,7 @@ import os import inspect -from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload +from typing import Any, Union, Mapping, TypeVar, Callable, Awaitable, overload, Optional from typing_extensions import Self, override import httpx @@ -274,6 +274,21 @@ def _get_azure_ad_token(self) -> str | None: return None + @override + def _calculate_retry_timeout( + self, + remaining_retries: int, + options: FinalRequestOptions, + response_headers: Optional[httpx.Headers] = None, + ) -> float: + try: + if response_headers: + return float(response_headers["retry-after-ms"]) / 1000 + except (KeyError, ValueError): + pass + + return super()._calculate_retry_timeout(remaining_retries, options, response_headers) + @override def _prepare_options(self, options: FinalRequestOptions) -> None: headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {} @@ -509,6 +524,21 @@ async def _get_azure_ad_token(self) -> str | None: return None + @override + def _calculate_retry_timeout( + self, + remaining_retries: int, + options: FinalRequestOptions, + response_headers: Optional[httpx.Headers] = None, + ) -> float: + try: + if response_headers: + return float(response_headers["retry-after-ms"]) / 1000 + except (KeyError, ValueError): + pass + + return super()._calculate_retry_timeout(remaining_retries, options, response_headers) + @override async def _prepare_options(self, options: FinalRequestOptions) -> None: headers: dict[str, str | Omit] = {**options.headers} if is_given(options.headers) else {} diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index 9360b2925a..b566ad8029 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -2,6 +2,7 @@ from typing_extensions import Literal import pytest +import httpx from openai._models import FinalRequestOptions from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI @@ -64,3 +65,41 @@ def test_client_copying_override_options(client: Client) -> None: api_version="2022-05-01", ) assert copied._custom_query == {"api-version": "2022-05-01"} + + +@pytest.mark.parametrize( + "client,headers,timeout", + [ + (sync_client, {"retry-after-ms": "2000"}, 2.0), + (sync_client, {"retry-after-ms": "2", "retry-after": "1"}, 0.002), + (async_client, {"retry-after-ms": "2000"}, 2.0), + (async_client, {"retry-after-ms": "2", "retry-after": "1"}, 0.002), + ], +) +def test_parse_retry_after_ms_header(client: Client, headers: httpx.Headers, timeout: float) -> None: + headers = httpx.Headers(headers) + options = FinalRequestOptions(method="post", url="/completions") + retry_timeout = client._calculate_retry_timeout( + remaining_retries=2, + options=options, + response_headers=headers + ) + assert retry_timeout == timeout + + +@pytest.mark.parametrize( + "client,headers", + [ + (sync_client, {}), + (async_client, {}), + ], +) +def test_no_retry_after_header(client: Client, headers: httpx.Headers) -> None: + headers = httpx.Headers(headers) + options = FinalRequestOptions(method="post", url="/completions") + retry_timeout = client._calculate_retry_timeout( + remaining_retries=2, + options=options, + response_headers=headers + ) + assert retry_timeout From d77426bdf8cbf453da502632fd2c80df65cd0ddc Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Thu, 4 Jan 2024 13:40:02 -0800 Subject: [PATCH 2/3] few more tests --- tests/lib/test_azure.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index b566ad8029..964e58c711 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -72,8 +72,10 @@ def test_client_copying_override_options(client: Client) -> None: [ (sync_client, {"retry-after-ms": "2000"}, 2.0), (sync_client, {"retry-after-ms": "2", "retry-after": "1"}, 0.002), + (sync_client, {"Retry-After-Ms": "2", "Retry-After": "1"}, 0.002), (async_client, {"retry-after-ms": "2000"}, 2.0), (async_client, {"retry-after-ms": "2", "retry-after": "1"}, 0.002), + (async_client, {"Retry-After-Ms": "2", "Retry-After": "1"}, 0.002), ], ) def test_parse_retry_after_ms_header(client: Client, headers: httpx.Headers, timeout: float) -> None: @@ -92,6 +94,8 @@ def test_parse_retry_after_ms_header(client: Client, headers: httpx.Headers, tim [ (sync_client, {}), (async_client, {}), + (sync_client, None), + (async_client, None), ], ) def test_no_retry_after_header(client: Client, headers: httpx.Headers) -> None: @@ -102,4 +106,23 @@ def test_no_retry_after_header(client: Client, headers: httpx.Headers) -> None: options=options, response_headers=headers ) - assert retry_timeout + assert retry_timeout # uses default retry implementation + + + +@pytest.mark.parametrize( + "client,headers", + [ + (sync_client, {"retry-after-ms": "invalid"}), + (sync_client, {"retry-after-ms": "invalid"}), + ], +) +def test_invalid_retry_after_header(client: Client, headers: httpx.Headers) -> None: + headers = httpx.Headers(headers) + options = FinalRequestOptions(method="post", url="/completions") + retry_timeout = client._calculate_retry_timeout( + remaining_retries=2, + options=options, + response_headers=headers + ) + assert retry_timeout # uses default retry implementation From afcb1b5ade93d76c9ee1fe5589661633942cb3af Mon Sep 17 00:00:00 2001 From: Krista Pratico Date: Thu, 4 Jan 2024 13:53:05 -0800 Subject: [PATCH 3/3] fix test --- tests/lib/test_azure.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/lib/test_azure.py b/tests/lib/test_azure.py index 964e58c711..0da7447916 100644 --- a/tests/lib/test_azure.py +++ b/tests/lib/test_azure.py @@ -114,7 +114,7 @@ def test_no_retry_after_header(client: Client, headers: httpx.Headers) -> None: "client,headers", [ (sync_client, {"retry-after-ms": "invalid"}), - (sync_client, {"retry-after-ms": "invalid"}), + (async_client, {"retry-after-ms": "invalid"}), ], ) def test_invalid_retry_after_header(client: Client, headers: httpx.Headers) -> None: