diff --git a/src/vunnel/providers/mariner/parser.py b/src/vunnel/providers/mariner/parser.py
index 5736ba42..678f3432 100644
--- a/src/vunnel/providers/mariner/parser.py
+++ b/src/vunnel/providers/mariner/parser.py
@@ -3,13 +3,12 @@
 import os
 from typing import TYPE_CHECKING, Any
 
-import requests
 from lxml import etree
 from xsdata.formats.dataclass.parsers import XmlParser
 from xsdata.formats.dataclass.parsers.config import ParserConfig
 
 from vunnel.providers.mariner.model import Definition, RpminfoObject, RpminfoState, RpminfoTest
-from vunnel.utils import retry_with_backoff
+from vunnel.utils import http
 from vunnel.utils.vulnerability import FixedIn, Vulnerability
 
 if TYPE_CHECKING:
@@ -187,11 +186,10 @@ def __init__(self, workspace: Workspace, download_timeout: int, allow_versions:
     def _download(self) -> list[str]:
         return [self._download_version(v) for v in self.allow_versions]
 
-    @retry_with_backoff()
     def _download_version(self, version: str) -> str:
         filename = MARINER_URL_FILENAME.format(version)
         url = MARINER_URL_BASE.format(filename)
-        r = requests.get(url, timeout=self.download_timeout)
+        r = http.get(url, self.logger, timeout=self.download_timeout)
         destination = os.path.join(self.workspace.input_path, filename)
         with open(destination, "wb") as writer:
             writer.write(r.content)
diff --git a/src/vunnel/utils/http.py b/src/vunnel/utils/http.py
new file mode 100644
index 00000000..1c1ccc4b
--- /dev/null
+++ b/src/vunnel/utils/http.py
@@ -0,0 +1,51 @@
+import logging
+import time
+from typing import Any
+
+import requests
+
+DEFAULT_TIMEOUT = 30
+
+
+def get(
+    url: str,
+    logger: logging.Logger,
+    retries: int = 5,
+    backoff_in_seconds: int = 3,
+    timeout: int = DEFAULT_TIMEOUT,
+    **kwargs: Any,
+) -> requests.Response:
+    """Issue an HTTP GET, retrying on any failure with a fixed delay between attempts.
+
+    Makes up to ``retries + 1`` attempts (a negative ``retries`` is treated as 0), sleeping
+    ``backoff_in_seconds`` before each retry. Every failed attempt is logged as a warning;
+    once all attempts are exhausted the last exception is logged as an error and re-raised.
+
+    Extra keyword arguments are passed through to ``requests.get``.
+    """
+    logger.debug(f"http GET {url}")
+    total_attempts = max(retries, 0) + 1
+    last_exception: Exception | None = None
+    for attempt in range(total_attempts):
+        if last_exception is not None:
+            time.sleep(backoff_in_seconds)
+        will_retry = ""
+        if attempt < total_attempts - 1:
+            will_retry = f" (will retry in {backoff_in_seconds} seconds)"
+        try:
+            response = requests.get(url, timeout=timeout, **kwargs)
+            response.raise_for_status()
+            return response
+        except requests.exceptions.HTTPError as e:
+            last_exception = e
+            # HTTPError includes the attempted request, so don't include it redundantly here
+            logger.warning(f"attempt {attempt + 1} of {total_attempts} failed{will_retry}: {e}")
+        except Exception as e:
+            last_exception = e
+            # this is an unexpected exception type, so include the attempted request in case the
+            # message from the unexpected exception doesn't.
+            logger.warning(f"attempt {attempt + 1} of {total_attempts}{will_retry}: unexpected exception during GET {url}: {e}")
+    if last_exception is not None:
+        logger.error(f"last retry of GET {url} failed with {last_exception}")
+        raise last_exception
+    raise Exception("unreachable")
diff --git a/tests/conftest.py b/tests/conftest.py
index 2e934cd3..cde2bf51 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -207,3 +207,15 @@ def apply(d: str, name: str = ""):
         return path
 
     return apply
+
+
+@pytest.fixture()
+def disable_get_requests(monkeypatch):
+    """Make any call to ``vunnel.utils.http.get`` raise, so a test fails if it attempts an HTTP GET."""
+
+    def disabled(*args, **kwargs):
+        raise RuntimeError("requests disabled but HTTP GET attempted")
+
+    from vunnel.utils import http
+
+    monkeypatch.setattr(http, "get", disabled)
diff --git a/tests/unit/providers/mariner/test_mariner.py b/tests/unit/providers/mariner/test_mariner.py
index 0917cbce..1fd03045 100644
--- a/tests/unit/providers/mariner/test_mariner.py
+++ b/tests/unit/providers/mariner/test_mariner.py
@@ -5,7 +5,7 @@
 import pytest
 from pytest_unordered import unordered
 
-from vunnel import result, workspace
+from vunnel import result, workspace, utils
 from vunnel.providers.mariner import Config, Provider, parser
 from vunnel.providers.mariner.parser import MarinerXmlFile
 from vunnel.utils.vulnerability import Vulnerability, FixedIn
@@ -86,14 +86,6 @@ def test_parse(tmpdir, helpers, input_file, expected):
     assert vulnerabilities == expected
 
 
-@pytest.fixture()
-def disable_get_requests(monkeypatch):
-    def disabled(*args, **kwargs):
-        raise RuntimeError("requests disabled but HTTP GET attempted")
-
-    monkeypatch.setattr(parser.requests, "get", disabled)
-
-
 def test_provider_schema(helpers, disable_get_requests, monkeypatch):
     workspace = helpers.provider_workspace_helper(name=Provider.name())
 
diff --git a/tests/unit/utils/test_http.py b/tests/unit/utils/test_http.py
new file mode 100644
index 00000000..75105d64
--- /dev/null
+++ b/tests/unit/utils/test_http.py
@@ -0,0 +1,84 @@
+from __future__ import annotations
+
+import logging
+import pytest
+import requests
+from unittest.mock import patch, MagicMock, call
+from vunnel.utils import http
+
+
+class TestGetRequests:
+    @pytest.fixture()
+    def mock_logger(self):
+        logger = logging.getLogger("test-http-utils")
+        return MagicMock(logger, autospec=True)
+
+    @pytest.fixture()
+    def error_response(self):
+        mock_response = MagicMock()
+        mock_response.raise_for_status = MagicMock()
+        mock_response.raise_for_status.side_effect = requests.HTTPError("HTTP ERROR")
+        return mock_response
+
+    @pytest.fixture()
+    def success_response(self):
+        response = MagicMock()
+        response.raise_for_status = MagicMock()
+        response.raise_for_status.side_effect = None
+        response.status_code = 200
+        return response
+
+    @patch("time.sleep")
+    @patch("requests.get")
+    def test_raises_when_out_of_retries(self, mock_requests, mock_sleep, mock_logger, error_response):
+        mock_requests.side_effect = [Exception("could not attempt request"), error_response, error_response]
+        with pytest.raises(requests.HTTPError):
+            http.get("http://example.com/some-path", mock_logger, retries=2, backoff_in_seconds=3)
+        mock_logger.error.assert_called()
+
+    @patch("time.sleep")
+    @patch("requests.get")
+    def test_succeeds_if_retries_succeed(self, mock_requests, mock_sleep, mock_logger, error_response, success_response):
+        mock_requests.side_effect = [error_response, success_response]
+        http.get("http://example.com/some-path", mock_logger, retries=1, backoff_in_seconds=22)
+        mock_sleep.assert_called_with(22)
+        mock_logger.warning.assert_called()
+        mock_logger.error.assert_not_called()
+        mock_requests.assert_called_with("http://example.com/some-path", timeout=http.DEFAULT_TIMEOUT)
+
+    @patch("requests.get")
+    def test_timeout_is_passed_in(self, mock_requests, mock_logger):
+        http.get("http://example.com/some-path", mock_logger, timeout=12345)
+        mock_requests.assert_called_with("http://example.com/some-path", timeout=12345)
+
+    @patch("time.sleep")
+    @patch("requests.get")
+    def test_sleeps_right_amount_between_retries(self, mock_requests, mock_sleep, mock_logger, error_response, success_response):
+        mock_requests.side_effect = [error_response, error_response, error_response, success_response]
+        http.get("http://example.com/some-path", mock_logger, backoff_in_seconds=123, retries=3)
+        assert mock_sleep.call_args_list == [call(123), call(123), call(123)]
+
+    @patch("time.sleep")
+    @patch("requests.get")
+    def test_it_logs_the_url_on_failure(self, mock_requests, mock_sleep, mock_logger, error_response):
+        mock_requests.side_effect = [error_response, error_response, error_response]
+        url = "http://example.com/some-path"
+        with pytest.raises(requests.HTTPError):
+            http.get(url, mock_logger, retries=2)
+
+        assert url in mock_logger.error.call_args.args[0]
+
+    @patch("time.sleep")
+    @patch("requests.get")
+    def test_it_log_warns_errors(self, mock_requests, mock_sleep, mock_logger, error_response, success_response):
+        mock_requests.side_effect = [error_response, success_response]
+        http.get("http://example.com/some-path", mock_logger, retries=1, backoff_in_seconds=33)
+        assert "HTTP ERROR" in mock_logger.warning.call_args.args[0]
+        assert "will retry in 33 seconds" in mock_logger.warning.call_args.args[0]
+
+    @patch("time.sleep")
+    @patch("requests.get")
+    def test_it_logs_total_attempt_count(self, mock_requests, mock_sleep, mock_logger, error_response, success_response):
+        mock_requests.side_effect = [error_response, success_response]
+        http.get("http://example.com/some-path", mock_logger, retries=1)
+        assert "attempt 1 of 2" in mock_logger.warning.call_args.args[0]