Skip to content

Commit

Permalink
refactor: use new http wrapper in nvd provider (#385)
Browse files Browse the repository at this point in the history
Include cleaning up incorrect indexing into a response object,
which static analysis could detect since the wrapper annotates its
return type.

Signed-off-by: Will Murphy <[email protected]>
  • Loading branch information
willmurphyscode authored Nov 3, 2023
1 parent b904eb6 commit bef4263
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 25 deletions.
15 changes: 7 additions & 8 deletions src/vunnel/providers/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from typing import TYPE_CHECKING, Any

import orjson
import requests

from vunnel import utils
from vunnel.utils import http

if TYPE_CHECKING:
from collections.abc import Generator

import requests


class NvdAPI:
_cve_api_url_: str = "https://services.nvd.nist.gov/rest/json/cves/2.0"
Expand Down Expand Up @@ -119,7 +120,7 @@ def _request_all_pages(

payload = orjson.loads(response.text)
if "message" in payload:
raise RuntimeError(f"API error: {response['message']}")
raise RuntimeError(f"API error: {payload['message']}")

yield payload

Expand All @@ -144,17 +145,15 @@ def _request_all_pages(

index += results_per_page

# NVD rate-limiting is detailed at https://nvd.nist.gov/developers/start-here and currently resets on a 30 second
# rolling window, so setting retry to start trying again after 30 seconds.
@utils.retry_with_backoff(backoff_in_seconds=30)
def _request(self, url: str, parameters: dict[str, str], headers: dict[str, str]) -> requests.Response:
# this is to prevent from encoding the ':' in any timestamps passed
# (e.g. prevent pubStartDate=2002-01-01T00%3A00%3A00 , want pubStartDate=2002-01-01T00:00:00)
payload_str = urllib.parse.urlencode(parameters, safe=":")

response = requests.get(url, params=payload_str, headers=headers, timeout=self.timeout)
# NVD rate-limiting is detailed at https://nvd.nist.gov/developers/start-here and currently resets on a 30 second
# rolling window, so setting retry to start trying again after 30 seconds.
response = http.get(url, self.logger, backoff_in_seconds=30, params=payload_str, headers=headers, timeout=self.timeout)
response.encoding = "utf-8"
response.raise_for_status()

return response

Expand Down
37 changes: 28 additions & 9 deletions tests/unit/providers/nvd/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import pytest
from vunnel.providers.nvd import api
from vunnel.utils import http


@pytest.fixture()
Expand All @@ -25,7 +26,7 @@ def simple_mock(mocker):
),
]

return mocker.patch.object(api.requests, "get", side_effect=responses), [first_json_dict], subject
return mocker.patch.object(http, "get", side_effect=responses), [first_json_dict], subject


class TestAPI:
Expand All @@ -36,9 +37,11 @@ def test_cve_no_api_key(self, simple_mock, mocker):
vulnerabilities = list(subject.cve("CVE-2020-0000"))

assert vulnerabilities == responses
assert api.requests.get.call_args_list == [
assert mock.call_args_list == [
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
backoff_in_seconds=30,
params="cveId=CVE-2020-0000",
headers={"content-type": "application/json"},
timeout=1,
Expand All @@ -51,10 +54,12 @@ def test_cve_single_cve(self, simple_mock, mocker):
vulnerabilities = list(subject.cve("CVE-2020-0000"))

assert vulnerabilities == responses
assert api.requests.get.call_args_list == [
assert mock.call_args_list == [
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="cveId=CVE-2020-0000",
backoff_in_seconds=30,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand Down Expand Up @@ -103,27 +108,33 @@ def test_cve_multi_page(self, mocker):
),
)

mocker.patch.object(api.requests, "get", side_effect=responses)
mock = mocker.patch.object(http, "get", side_effect=responses)

vulnerabilities = list(subject.cve())

assert vulnerabilities == json_responses
assert api.requests.get.call_args_list == [
assert mock.call_args_list == [
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="",
backoff_in_seconds=30,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=3&startIndex=3",
backoff_in_seconds=30,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=3&startIndex=6",
backoff_in_seconds=30,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -140,10 +151,12 @@ def test_cve_pub_date_range(self, simple_mock, mocker):
)

assert vulnerabilities
assert api.requests.get.call_args_list == [
assert mock.call_args_list == [
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="pubStartDate=2019-12-04T00:00:00&pubEndDate=2019-12-05T00:00:00",
backoff_in_seconds=30,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -160,10 +173,12 @@ def test_cve_last_modified_date_range(self, simple_mock, mocker):
)

assert vulnerabilities
assert api.requests.get.call_args_list == [
assert mock.call_args_list == [
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="lastModStartDate=2019-12-04T00:00:00&lastModEndDate=2019-12-05T00:00:00",
backoff_in_seconds=30,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -177,10 +192,12 @@ def test_results_per_page(self, simple_mock, mocker):

list(subject.cve(results_per_page=5))

assert api.requests.get.call_args_list == [
assert mock.call_args_list == [
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=5",
backoff_in_seconds=30,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -192,10 +209,12 @@ def test_cve_history(self, simple_mock, mocker):
changes = list(subject.cve_history("CVE-2020-0000"))

assert changes
assert api.requests.get.call_args_list == [
assert mock.call_args_list == [
mocker.call(
"https://services.nvd.nist.gov/rest/json/cvehistory/2.0",
subject.logger,
params="cveId=CVE-2020-0000",
backoff_in_seconds=30,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand Down
8 changes: 0 additions & 8 deletions tests/unit/providers/nvd/test_nvd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,6 @@
from vunnel.providers.nvd import api as nvd_api


@pytest.fixture()
def disable_get_requests(monkeypatch):
def disabled(*args, **kwargs):
raise RuntimeError("requests disabled but HTTP GET attempted")

monkeypatch.setattr(nvd_api.requests, "get", disabled)


@pytest.mark.parametrize(
("policy", "should_raise"),
(
Expand Down

0 comments on commit bef4263

Please sign in to comment.