From c655fa253294bf3b8f65be1b991a93ce0890a46a Mon Sep 17 00:00:00 2001 From: Philip Gichuhi Date: Mon, 2 Dec 2024 15:03:14 +0300 Subject: [PATCH] fix: Ensures retry count is incremented based on value in retry-attempt header --- .../kiota_http/middleware/retry_handler.py | 22 ++++++++++-------- .../middleware_tests/test_retry_handler.py | 23 +++++++++++++++++++ 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/packages/http/httpx/kiota_http/middleware/retry_handler.py b/packages/http/httpx/kiota_http/middleware/retry_handler.py index e5e07ae..bc58075 100644 --- a/packages/http/httpx/kiota_http/middleware/retry_handler.py +++ b/packages/http/httpx/kiota_http/middleware/retry_handler.py @@ -14,6 +14,8 @@ from .middleware import BaseMiddleware from .options import RetryHandlerOption +RETRY_ATTEMPT = "Retry-Attempt" + class RetryHandler(BaseMiddleware): """ @@ -71,20 +73,21 @@ async def send(self, request: httpx.Request, transport: httpx.AsyncBaseTransport Sends the http request object to the next middleware or retries the request if necessary. """ response = None - retry_count = 0 - _span = self._create_observability_span(request, "RetryHandler_send") current_options = self._get_current_options(request) _span.set_attribute("com.microsoft.kiota.handler.retry.enable", True) _span.end() retry_valid = current_options.should_retry - _retry_span = self._create_observability_span( - request, f"RetryHandler_send - attempt {retry_count}" - ) + while retry_valid: response = await super().send(request, transport) - _retry_span.set_attribute(HTTP_RESPONSE_STATUS_CODE, response.status_code) # check that max retries has not been hit + retry_count = 0 if RETRY_ATTEMPT not in response.request.headers else int( + response.request.headers[RETRY_ATTEMPT] + ) + _retry_span = self._create_observability_span( + request, f"RetryHandler_send - attempt {retry_count}" + ) retry_valid = self.check_retry_valid(retry_count, current_options) # Get the delay time between retries @@ -97,13 +100,14 @@ async def send(self, request: httpx.Request, transport: httpx.AsyncBaseTransport time.sleep(delay) # increment the count for retries retry_count += 1 - request.headers.update({'retry-attempt': f'{retry_count}'}) + request.headers.update({RETRY_ATTEMPT: f'{retry_count}'}) + _retry_span.set_attribute(HTTP_RESPONSE_STATUS_CODE, response.status_code) _retry_span.set_attribute('http.request.resend_count', retry_count) continue + _retry_span.end() break if response is None: response = await super().send(request, transport) - _retry_span.end() return response def _get_current_options(self, request: httpx.Request) -> RetryHandlerOption: @@ -165,7 +169,7 @@ def check_retry_valid(self, retry_count, options): return True return False - def get_delay_time(self, retry_count, response=None, delay=0): + def get_delay_time(self, retry_count, response=None, delay=RetryHandlerOption.DEFAULT_DELAY): """ Get the time in seconds to delay between retry attempts. Respects a retry-after header in the response if provided diff --git a/packages/http/httpx/tests/middleware_tests/test_retry_handler.py b/packages/http/httpx/tests/middleware_tests/test_retry_handler.py index 6b27517..de7dd4f 100644 --- a/packages/http/httpx/tests/middleware_tests/test_retry_handler.py +++ b/packages/http/httpx/tests/middleware_tests/test_retry_handler.py @@ -238,6 +238,29 @@ def request_handler(request: httpx.Request): assert resp.status_code == 429 assert RETRY_ATTEMPT not in resp.request.headers +@pytest.mark.asyncio +async def test_max_retries_respected(): + """Test that a request is not retried more than max_retries configured""" + + def request_handler(request: httpx.Request): + if RETRY_ATTEMPT in request.headers: + return httpx.Response(200, ) + return httpx.Response( + TOO_MANY_REQUESTS, + ) + + # Retry-after value takes precedence over the RetryHandlerOption value specified here + handler = RetryHandler(RetryHandlerOption(10, 3, True)) + request = httpx.Request( + 'GET', + BASE_URL, + headers={RETRY_ATTEMPT: '5'} # value exceeds max retries configured + ) + mock_transport = httpx.MockTransport(request_handler) + resp = await handler.send(request, mock_transport) + assert resp.status_code == 200 + assert RETRY_ATTEMPT in resp.request.headers + assert resp.request.headers[RETRY_ATTEMPT] == '5' @pytest.mark.asyncio async def test_retry_options_apply_per_request():