From c7e5be929bbeb38d7c6291113aa3b0ce3ba50c20 Mon Sep 17 00:00:00 2001 From: Alexandre Tullot Date: Wed, 28 Jun 2023 18:54:55 +0200 Subject: [PATCH] Fix multiple calls to task_done (#19) Co-authored-by: ctmbl --- src/api/rate_limiter.py | 18 ++++++++------ tests/test_rate_limiter.py | 51 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 8 deletions(-) create mode 100644 tests/test_rate_limiter.py diff --git a/src/api/rate_limiter.py b/src/api/rate_limiter.py index 9c613ed..c4861af 100644 --- a/src/api/rate_limiter.py +++ b/src/api/rate_limiter.py @@ -36,7 +36,7 @@ def __init__(self): else: self._max_retry = DEFAULT_MAX_RETRY - asyncio.create_task(self.handle_requests()) + self.task = asyncio.create_task(self.handle_requests()) self.logger = logging.getLogger(__name__) @@ -87,22 +87,24 @@ async def handle_requests(self): request.cookies, resp.status_code, ) - if retry_count < self._max_retry: - retry = True - retry_count += 1 - else: + if retry_count >= self._max_retry: self.logger.error( "Failed to get request after %s attempt. We could be banned :(", self._max_retry, ) raise RuntimeError("Looks like a ban to me :'(") - else: - data = resp.json() + + # Retry the request + retry = True + retry_count += 1 + continue + else: raise NotImplementedError("Only GET method implemented for now.") + # The request did pass all the tests successfully # we send back the response and trigger the event of this request - self.requests[request.key]["result"] = data + self.requests[request.key]["result"] = resp.json() self.requests[request.key]["event"].set() # finally we inform the queue of the end of the process self.queue.task_done() diff --git a/tests/test_rate_limiter.py b/tests/test_rate_limiter.py new file mode 100644 index 0000000..340bd89 --- /dev/null +++ b/tests/test_rate_limiter.py @@ -0,0 +1,51 @@ +import pytest + +from api.rate_limiter import DEFAULT_MAX_RETRY, RateLimiter + + +@pytest.mark.asyncio +async def test_default_max_retry(monkeypatch): + monkeypatch.delenv("MAX_API_RETRY", raising=False) + rate_limiter = RateLimiter() + + # pylint: disable-next=protected-access + assert rate_limiter._max_retry == DEFAULT_MAX_RETRY + + rate_limiter.task.cancel() + + +@pytest.mark.asyncio +async def test_env_max_retry(monkeypatch): + monkeypatch.setenv("MAX_API_RETRY", "42") + rate_limiter = RateLimiter() + + # pylint: disable-next=protected-access + assert rate_limiter._max_retry == 42 + + rate_limiter.task.cancel() + + +@pytest.mark.asyncio +async def test_request_passes(monkeypatch, mocker): + mocked_get = mocker.MagicMock() + resp = mocker.MagicMock() + + data = {"some data": "when request returns"} + resp.status_code = 200 + resp.json.return_value = data + mocked_get.return_value = resp + + monkeypatch.setattr("requests.get", mocked_get) + + # Trigger test + rate_limiter = RateLimiter() + + url = "url" + cookies = {"cookie": "dummy"} + result = await rate_limiter.make_request(url, cookies, "GET") + + mocked_get.assert_called_once_with(url, cookies=cookies) + assert result is data + + # Clean task properly + rate_limiter.task.cancel()