From 3ef26a6e7cde66ae9222e006dd8ddc9b2e1a1459 Mon Sep 17 00:00:00 2001 From: Stainless Bot <107565488+stainless-bot@users.noreply.github.com> Date: Mon, 1 Apr 2024 10:47:01 +0100 Subject: [PATCH] chore(client): validate that max_retries is not None (#430) --- src/anthropic/_base_client.py | 5 +++++ tests/test_client.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/src/anthropic/_base_client.py b/src/anthropic/_base_client.py index 7a8595c1..d4215673 100644 --- a/src/anthropic/_base_client.py +++ b/src/anthropic/_base_client.py @@ -361,6 +361,11 @@ def __init__( self._strict_response_validation = _strict_response_validation self._idempotency_header = None + if max_retries is None: # pyright: ignore[reportUnnecessaryComparison] + raise TypeError( + "max_retries cannot be None. If you want to disable retries, pass `0`; if you want unlimited retries, pass `math.inf` or a very high number; if you want the default behavior, pass `anthropic.DEFAULT_MAX_RETRIES`" + ) + def _enforce_trailing_slash(self, url: URL) -> URL: if url.raw_path.endswith(b"/"): return url diff --git a/tests/test_client.py b/tests/test_client.py index 1ce6332a..43214bd5 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -753,6 +753,10 @@ class Model(BaseModel): assert isinstance(exc.value.__cause__, ValidationError) + def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + Anthropic(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)) + @pytest.mark.respx(base_url=base_url) def test_default_stream_cls(self, respx_mock: MockRouter) -> None: class Model(BaseModel): @@ -1578,6 +1582,12 @@ class Model(BaseModel): assert isinstance(exc.value.__cause__, ValidationError) + async def test_client_max_retries_validation(self) -> None: + with pytest.raises(TypeError, match=r"max_retries cannot be None"): + AsyncAnthropic( + base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None) + ) + @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_default_stream_cls(self, respx_mock: MockRouter) -> None: