From f2934fd310101ab5ae93f4728406b61478b9fd14 Mon Sep 17 00:00:00 2001 From: Stainless Bot Date: Mon, 1 Apr 2024 09:47:16 +0000 Subject: [PATCH] chore(client): validate that max_retries is not None --- src/openai/_base_client.py | 5 +++++ tests/test_client.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/openai/_base_client.py b/src/openai/_base_client.py index 7a8595c173..502ed7c7ae 100644 --- a/src/openai/_base_client.py +++ b/src/openai/_base_client.py @@ -361,6 +361,11 @@ def __init__( self._strict_response_validation = _strict_response_validation self._idempotency_header = None + if max_retries is None: # pyright: ignore[reportUnnecessaryComparison] + raise TypeError( + "max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `openai.DEFAULT_MAX_RETRIES`" + ) + def _enforce_trailing_slash(self, url: URL) -> URL: if url.raw_path.endswith(b"/"): return url diff --git a/tests/test_client.py b/tests/test_client.py index dab1cb0efd..ba85fd9d5f 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -646,6 +646,10 @@ class Model(BaseModel): assert isinstance(exc.value.__cause__, ValidationError) + def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)) + @pytest.mark.respx(base_url=base_url) def test_default_stream_cls(self, respx_mock: MockRouter) -> None: class Model(BaseModel): @@ -1368,6 +1372,12 @@ class Model(BaseModel): assert isinstance(exc.value.__cause__, ValidationError) + async def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + AsyncOpenAI( + base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None) + ) + @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_default_stream_cls(self, respx_mock: MockRouter) -> None: