Skip to content

Commit

Permalink
Allow Amazon ALAS downloads to return HTTP 403 up to a configurable limit (#564)
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Goodman <[email protected]>
  • Loading branch information
wagoodman authored May 2, 2024
1 parent ae01fc9 commit a33a36a
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/vunnel/providers/amazon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Config:
),
)
request_timeout: int = 125
max_allowed_alas_http_403: int = 25

def __post_init__(self) -> None:
self.security_advisories = {str(k): str(v) for k, v in self.security_advisories.items()}
Expand All @@ -45,6 +46,7 @@ def __init__(self, root: str, config: Config | None = None):
security_advisories=config.security_advisories,
download_timeout=config.request_timeout,
logger=self.logger,
max_allowed_alas_http_403=config.max_allowed_alas_http_403,
)

@classmethod
Expand Down
48 changes: 44 additions & 4 deletions src/vunnel/providers/amazon/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,22 @@ class Parser:
_rss_file_name_ = "{}_rss.xml"
_html_dir_name_ = "{}_html"

def __init__(self, workspace, download_timeout=125, security_advisories=None, logger=None):
def __init__( # noqa: PLR0913
self,
workspace,
download_timeout=125,
security_advisories=None,
logger=None,
max_allowed_alas_http_403=25,
):
self.workspace = workspace
self.version_url_map = security_advisories if security_advisories else amazon_security_advisories
self.download_timeout = download_timeout
self.max_allowed_alas_http_403 = max_allowed_alas_http_403
self.urls = []

self.alas_403s = []

if not logger:
logger = logging.getLogger(self.__class__.__name__)
self.logger = logger
Expand Down Expand Up @@ -85,14 +95,31 @@ def _parse_rss(self, file_path):

return sorted(alas_summaries)

def _get_alas_html(self, alas_url, alas_file, skip_if_exists=True):
def _alas_response_handler(self, response):
    """Status handler for ALAS requests that tolerates HTTP 403 responses.

    A 403 response is not raised: its URL is appended to ``self.alas_403s``
    and a warning is logged. For every other status code the decision is
    delegated to ``response.raise_for_status()``.
    """
    if response.status_code == 403:
        # remember the URL so the caller can count and report the 403s
        self.alas_403s.append(response.url)
        self.logger.warning(f"403 Forbidden: {response.url}")
    else:
        response.raise_for_status()

def _get_alas_html(self, alas_url, alas_file, skip_if_exists=True) -> str | None:
# attempt to download the alas html content
# if there is a 403, we will skip the download and track the url in self.alas_403s
# otherwise we will raise the exception
if skip_if_exists and os.path.exists(alas_file): # read alas from disk if its available
self.logger.debug(f"loading existing ALAS from {alas_file}")
with open(alas_file, encoding="utf-8") as fp:
content = fp.read()
return content # noqa: RET504
try:
r = http.get(alas_url, self.logger, timeout=self.download_timeout)
r = http.get(alas_url, self.logger, timeout=self.download_timeout, status_handler=self._alas_response_handler)
if r.status_code == 403:
if len(self.alas_403s) > self.max_allowed_alas_http_403:
raise ValueError(
f"exceeded maximum allowed 403 responses ({self.max_allowed_alas_http_403}) from ALAS requests",
)

return None
content = r.text
with open(alas_file, "w", encoding="utf-8") as fp:
fp.write(content)
Expand All @@ -116,6 +143,15 @@ def get_package_name_version(pkg):
return AlasFixedIn(pkg=name, ver=version)

def get(self, skip_if_exists=False):
    """Yield everything produced by ``self._get`` and report HTTP 403 failures.

    This is a generator that delegates to ``self._get(skip_if_exists)``.
    Whether iteration completes, is closed early, or raises, the ``finally``
    clause logs a warning with the number of URLs in ``self.alas_403s``,
    followed by one warning line per URL. Nothing is logged when the list
    is empty.
    """
    try:
        yield from self._get(skip_if_exists)
    finally:
        # runs on normal exhaustion as well as on error, so the summary is never lost
        if self.alas_403s:
            self.logger.warning(f"failed to fetch {len(self.alas_403s)} ALAS entries due to HTTP 403 response code")
            for url in self.alas_403s:
                self.logger.warning(f" - {url}")

def _get(self, skip_if_exists):
for version, url in self.version_url_map.items():
rss_file = os.path.join(self.workspace.input_path, self._rss_file_name_.format(version))
html_dir = os.path.join(self.workspace.input_path, self._html_dir_name_.format(version))
Expand All @@ -134,7 +170,11 @@ def get(self, skip_if_exists=False):
for alas in alas_summaries:
# download alas html content
alas_file = os.path.join(html_dir, alas.id)
html_content = self._get_alas_html(alas.url, alas_file)
html_content = self._get_alas_html(alas.url, alas_file, skip_if_exists=skip_if_exists)

if html_content is None:
self.logger.warning(f"skipping {alas.id}")
continue

# parse alas html for fixes
parser = PackagesHTMLParser()
Expand Down
1 change: 1 addition & 0 deletions tests/unit/cli/test-fixtures/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ providers:
amazon:
runtime: *runtime
request_timeout: 20
max_allowed_alas_http_403: 33
security_advisories:
42: "https://alas.aws.amazon.com/AL2/alas-42.rss"
chainguard:
Expand Down
1 change: 1 addition & 0 deletions tests/unit/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ def test_config(monkeypatch) -> None:
retry_delay: 5
result_store: sqlite
amazon:
max_allowed_alas_http_403: 25
request_timeout: 125
runtime:
existing_input: keep
Expand Down
1 change: 1 addition & 0 deletions tests/unit/cli/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def test_full_config(helpers):
},
runtime=runtime_cfg,
request_timeout=20,
max_allowed_alas_http_403=33,
),
chainguard=providers.chainguard.Config(
runtime=runtime_cfg,
Expand Down
42 changes: 41 additions & 1 deletion tests/unit/providers/amazon/test_amazon.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import shutil
from requests.exceptions import HTTPError
from unittest.mock import Mock

import pytest
from vunnel import result, workspace
from vunnel.utils import http
from vunnel.utils.http import requests
from vunnel.providers.amazon import Config, Provider, parser


Expand Down Expand Up @@ -76,6 +78,44 @@ def test_get_pkg_name_version(self):
b = parser.Parser.get_package_name_version("java-1.8.0-openjdk-1.8.0.161-0.b14.amzn2.x86_64")
assert a == b

def test_get_alas_html_403(self, helpers, monkeypatch, tmpdir):
    """A single 403 response makes ``_get_alas_html`` return ``None`` instead of raising."""

    # patch requests.get so that every request yields a response with status code 403
    def mock_get(*args, **kwargs):
        return Mock(status_code=403)

    monkeypatch.setattr(requests, "get", mock_get)

    alas_file = tmpdir.join("alas.html")

    p = parser.Parser(workspace=workspace.Workspace(helpers.local_dir("test-fixtures"), "test", create=True))
    alas = p._get_alas_html("https://example.com", alas_file)
    assert alas is None

def test_get_alas_html_raises_over_threshold(self, helpers, monkeypatch, tmpdir):
    """403s are tolerated up to ``max_allowed_alas_http_403``; one more raises ``ValueError``."""
    url = "https://example.com"

    # patch requests.get so that every request yields a response with status code 403
    def mock_get(*args, **kwargs):
        return Mock(status_code=403, url=url)

    monkeypatch.setattr(requests, "get", mock_get)

    alas_file = tmpdir.join("alas.html")

    p = parser.Parser(workspace=workspace.Workspace(helpers.local_dir("test-fixtures"), "test", create=True))
    p.max_allowed_alas_http_403 = 2

    # assert does not raise when at the threshold
    p.alas_403s = ["something"]
    p._get_alas_html(url, alas_file)
    assert p.alas_403s == ["something", url]

    # assert raises when above the threshold; match on the message so an
    # unrelated ValueError cannot satisfy the test
    with pytest.raises(ValueError, match="exceeded maximum allowed 403 responses"):
        p._get_alas_html(url, alas_file)

    # the offending URL is still recorded even though the call raised
    assert p.alas_403s == ["something", url, url]


def test_provider_schema(helpers, disable_get_requests, monkeypatch):
workspace = helpers.provider_workspace_helper(name=Provider.name())
Expand Down

0 comments on commit a33a36a

Please sign in to comment.