diff --git a/src/vunnel/providers/nvd/api.py b/src/vunnel/providers/nvd/api.py index 24e05872..ecc16894 100644 --- a/src/vunnel/providers/nvd/api.py +++ b/src/vunnel/providers/nvd/api.py @@ -6,13 +6,14 @@ from typing import TYPE_CHECKING, Any import orjson -import requests -from vunnel import utils +from vunnel.utils import http if TYPE_CHECKING: from collections.abc import Generator + import requests + class NvdAPI: _cve_api_url_: str = "https://services.nvd.nist.gov/rest/json/cves/2.0" @@ -119,7 +120,7 @@ def _request_all_pages( payload = orjson.loads(response.text) if "message" in payload: - raise RuntimeError(f"API error: {response['message']}") + raise RuntimeError(f"API error: {payload['message']}") yield payload @@ -144,17 +145,15 @@ def _request_all_pages( index += results_per_page - # NVD rate-limiting is detailed at https://nvd.nist.gov/developers/start-here and currently resets on a 30 second - # rolling window, so setting retry to start trying again after 30 seconds. - @utils.retry_with_backoff(backoff_in_seconds=30) def _request(self, url: str, parameters: dict[str, str], headers: dict[str, str]) -> requests.Response: # this is to prevent from encoding the ':' in any timestamps passed # (e.g. prevent pubStartDate=2002-01-01T00%3A00%3A00 , want pubStartDate=2002-01-01T00:00:00) payload_str = urllib.parse.urlencode(parameters, safe=":") - response = requests.get(url, params=payload_str, headers=headers, timeout=self.timeout) + # NVD rate-limiting is detailed at https://nvd.nist.gov/developers/start-here and currently resets on a 30 second + # rolling window, so setting retry to start trying again after 30 seconds. + response = http.get(url, self.logger, backoff_in_seconds=30, params=payload_str, headers=headers, timeout=self.timeout) response.encoding = "utf-8" - response.raise_for_status() return response diff --git a/tests/unit/providers/nvd/test_api.py b/tests/unit/providers/nvd/test_api.py index 08834946..d6019c90 100644 --- a/tests/unit/providers/nvd/test_api.py +++ b/tests/unit/providers/nvd/test_api.py @@ -5,6 +5,7 @@ import pytest from vunnel.providers.nvd import api +from vunnel.utils import http @pytest.fixture() @@ -25,7 +26,7 @@ def simple_mock(mocker): ), ] - return mocker.patch.object(api.requests, "get", side_effect=responses), [first_json_dict], subject + return mocker.patch.object(http, "get", side_effect=responses), [first_json_dict], subject class TestAPI: @@ -36,9 +37,11 @@ def test_cve_no_api_key(self, simple_mock, mocker): vulnerabilities = list(subject.cve("CVE-2020-0000")) assert vulnerabilities == responses - assert api.requests.get.call_args_list == [ + assert mock.call_args_list == [ mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", + subject.logger, + backoff_in_seconds=30, params="cveId=CVE-2020-0000", headers={"content-type": "application/json"}, timeout=1, @@ -51,10 +54,12 @@ def test_cve_single_cve(self, simple_mock, mocker): vulnerabilities = list(subject.cve("CVE-2020-0000")) assert vulnerabilities == responses - assert api.requests.get.call_args_list == [ + assert mock.call_args_list == [ mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", + subject.logger, params="cveId=CVE-2020-0000", + backoff_in_seconds=30, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -103,27 +108,33 @@ def test_cve_multi_page(self, mocker): ), ) - mocker.patch.object(api.requests, "get", side_effect=responses) + mock = mocker.patch.object(http, "get", side_effect=responses) vulnerabilities = list(subject.cve()) assert vulnerabilities == json_responses - assert api.requests.get.call_args_list == [ + assert mock.call_args_list == [ mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", + subject.logger, params="", + backoff_in_seconds=30, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", + subject.logger, params="resultsPerPage=3&startIndex=3", + backoff_in_seconds=30, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", + subject.logger, params="resultsPerPage=3&startIndex=6", + backoff_in_seconds=30, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -140,10 +151,12 @@ def test_cve_pub_date_range(self, simple_mock, mocker): ) assert vulnerabilities - assert api.requests.get.call_args_list == [ + assert mock.call_args_list == [ mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", + subject.logger, params="pubStartDate=2019-12-04T00:00:00&pubEndDate=2019-12-05T00:00:00", + backoff_in_seconds=30, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -160,10 +173,12 @@ def test_cve_last_modified_date_range(self, simple_mock, mocker): ) assert vulnerabilities - assert api.requests.get.call_args_list == [ + assert mock.call_args_list == [ mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", + subject.logger, params="lastModStartDate=2019-12-04T00:00:00&lastModEndDate=2019-12-05T00:00:00", + backoff_in_seconds=30, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -177,10 +192,12 @@ def test_results_per_page(self, simple_mock, mocker): list(subject.cve(results_per_page=5)) - assert api.requests.get.call_args_list == [ + assert mock.call_args_list == [ mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", + subject.logger, params="resultsPerPage=5", + backoff_in_seconds=30, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -192,10 +209,12 @@ def test_cve_history(self, simple_mock, mocker): changes = list(subject.cve_history("CVE-2020-0000")) assert changes - assert api.requests.get.call_args_list == [ + assert mock.call_args_list == [ mocker.call( "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", + subject.logger, params="cveId=CVE-2020-0000", + backoff_in_seconds=30, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), diff --git a/tests/unit/providers/nvd/test_nvd.py b/tests/unit/providers/nvd/test_nvd.py index 63757bca..71c5db0f 100644 --- a/tests/unit/providers/nvd/test_nvd.py +++ b/tests/unit/providers/nvd/test_nvd.py @@ -9,14 +9,6 @@ from vunnel.providers.nvd import api as nvd_api -@pytest.fixture() -def disable_get_requests(monkeypatch): - def disabled(*args, **kwargs): - raise RuntimeError("requests disabled but HTTP GET attempted") - - monkeypatch.setattr(nvd_api.requests, "get", disabled) - - @pytest.mark.parametrize( ("policy", "should_raise"), (