diff --git a/tests/test_testing_utils.py b/tests/test_testing_utils.py
index 17a6f2f9e..10a080cbb 100644
--- a/tests/test_testing_utils.py
+++ b/tests/test_testing_utils.py
@@ -3,6 +3,7 @@
 import platform
 import sys
 from pathlib import Path
+from urllib.error import URLError
 
 import numpy as np
 import pytest
@@ -139,24 +140,26 @@ def test_release_notes_file_not_implemented(self, tmp_path):
 class TestTestingFileAccessors:
     def test_unsafe_urls(self):
         with pytest.raises(
-            ValueError, match="GitHub URL not safe: 'ftp://domain.does.not.exist/'."
+            ValueError, match="GitHub URL not secure: 'ftp://domain.does.not.exist/'."
         ):
             utilities.open_dataset(
                 "doesnt_exist.nc", github_url="ftp://domain.does.not.exist/"
             )
 
         with pytest.raises(
-            ValueError, match="OPeNDAP URL not safe: 'ftp://domain.does.not.exist/'."
+            OSError,
+            match="OPeNDAP file not read. Verify that the service is available: "
+            "'https://seemingly.trustworthy.com/doesnt_exist.nc'",
         ):
             utilities.open_dataset(
-                "doesnt_exist.nc", dap_url="ftp://domain.does.not.exist/"
+                "doesnt_exist.nc", dap_url="https://seemingly.trustworthy.com/"
             )
 
-    def test_bad_opendap_url(self):
+    def test_malicious_urls(self):
         with pytest.raises(
-            OSError,
-            match="OPeNDAP file not read. Verify that the service is available.",
+            URLError,
+            match="urlopen error OPeNDAP URL is not well-formed: 'doesnt_exist.nc'",
         ):
             utilities.open_dataset(
-                "doesnt_exist.nc", dap_url="https://dap.service.does.not.exist/"
+                "doesnt_exist.nc", dap_url="Robert'); DROP TABLE STUDENTS; --"
             )
diff --git a/xclim/testing/utils.py b/xclim/testing/utils.py
index 009fffda6..b7e866f61 100644
--- a/xclim/testing/utils.py
+++ b/xclim/testing/utils.py
@@ -21,7 +21,7 @@
 from shutil import copy
 from typing import TextIO
 from urllib.error import HTTPError, URLError
-from urllib.parse import urljoin
+from urllib.parse import urljoin, urlparse
 from urllib.request import urlopen, urlretrieve
 
 import pandas as pd
@@ -62,6 +62,7 @@
 
 __all__ = [
     "_default_cache_dir",
+    "audit_url",
     "get_file",
     "get_local_testdata",
     "list_datasets",
@@ -73,12 +74,33 @@
 
 
 def file_md5_checksum(f_name):
-    hash_md5 = hashlib.md5()  # nosec
+    hash_md5 = hashlib.md5()  # noqa: S324
     with open(f_name, "rb") as f:
         hash_md5.update(f.read())
     return hash_md5.hexdigest()
 
 
+def audit_url(url: str, context: str | None = None) -> str:
+    """Check that the URL is well-formed and does not use plain (insecure) HTTP.
+
+    Raises
+    ------
+    URLError
+        If the URL is not well-formed or uses the insecure "http" scheme.
+    """
+    msg = ""
+    result = urlparse(url)
+    if result.scheme == "http":
+        msg = f"{context if context else ''} URL is not using secure HTTP: '{url}'".strip()
+    if not all([result.scheme, result.netloc]):
+        msg = f"{context if context else ''} URL is not well-formed: '{url}'".strip()
+
+    if msg:
+        logger.error(msg)
+        raise URLError(msg)
+    return url
+
+
 def get_file(
     name: str | os.PathLike[str] | Sequence[str | os.PathLike[str]],
     github_url: str = "https://github.com/Ouranosinc/xclim-testdata",
@@ -197,8 +219,8 @@
     md5_name = fullname.with_suffix(f"{suffix}.md5")
     md5_file = cache_dir / branch / md5_name
 
-    if not github_url.lower().startswith("http"):
-        raise ValueError(f"GitHub URL not safe: '{github_url}'.")
+    if not github_url.startswith("https"):
+        raise ValueError(f"GitHub URL not secure: '{github_url}'.")
 
     if local_file.is_file():
         local_md5 = file_md5_checksum(local_file)
@@ -206,7 +228,7 @@
             url = "/".join((github_url, "raw", branch, md5_name.as_posix()))
             msg = f"Attempting to fetch remote file md5: {md5_name.as_posix()}"
             logger.info(msg)
-            urlretrieve(url, md5_file)  # nosec
+            urlretrieve(audit_url(url), md5_file)  # noqa: S310
             with open(md5_file) as f:
                 remote_md5 = f.read()
             if local_md5.strip() != remote_md5.strip():
@@ -241,7 +263,7 @@
         msg = f"Fetching remote file: {fullname.as_posix()}"
         logger.info(msg)
         try:
-            urlretrieve(url, local_file)  # nosec
+            urlretrieve(audit_url(url), local_file)  # noqa: S310
         except HTTPError as e:
             msg = f"{fullname.as_posix()} not accessible in remote repository. Aborting file retrieval."
             raise FileNotFoundError(msg) from e
@@ -262,7 +284,7 @@
             url = "/".join((github_url, "raw", branch, md5_name.as_posix()))
             msg = f"Fetching remote file md5: {md5_name.as_posix()}"
             logger.info(msg)
-            urlretrieve(url, md5_file)  # nosec
+            urlretrieve(audit_url(url), md5_file)  # noqa: S310
         except (HTTPError, URLError) as e:
             msg = (
                 f"{md5_name.as_posix()} not accessible online. "
@@ -337,21 +359,17 @@
     suffix = ".nc"
     fullname = name.with_suffix(suffix)
 
-    if not github_url.lower().startswith("http"):
-        raise ValueError(f"GitHub URL not safe: '{github_url}'.")
-
     if dap_url is not None:
-        if not dap_url.lower().startswith("http"):
-            raise ValueError(f"OPeNDAP URL not safe: '{dap_url}'.")
-
-        dap_file = urljoin(dap_url, str(name))
+        dap_file_address = urljoin(dap_url, str(name))
         try:
-            ds = _open_dataset(dap_file, **kwargs)
+            ds = _open_dataset(audit_url(dap_file_address, context="OPeNDAP"), **kwargs)
             return ds
+        except URLError:
+            raise
         except OSError as err:
-            msg = "OPeNDAP file not read. Verify that the service is available."
+            msg = f"OPeNDAP file not read. Verify that the service is available: '{dap_file_address}'"
             logger.error(msg)
             raise OSError(msg) from err
 
     local_file = _get(
         fullname=fullname,
@@ -378,8 +396,8 @@ def list_datasets(github_repo="Ouranosinc/xclim-testdata", branch="main"):
     This uses an unauthenticated call to GitHub's REST API, so it is limited to 60 requests per hour (per IP).
     A single call of this function triggers one request per subdirectory, so use with parsimony.
""" - with urlopen( # nosec - f"https://api.github.com/repos/{github_repo}/contents?ref={branch}" + with urlopen( # noqa: S310 + audit_url(f"https://api.github.com/repos/{github_repo}/contents?ref={branch}") ) as res: base = json.loads(res.read().decode()) records = [] @@ -387,7 +405,7 @@ def list_datasets(github_repo="Ouranosinc/xclim-testdata", branch="main"): if folder["path"].startswith(".") or folder["size"] > 0: # drop hidden folders and other files. continue - with urlopen(folder["url"]) as res: # nosec + with urlopen(audit_url(folder["url"])) as res: # noqa: S310 listing = json.loads(res.read().decode()) for file in listing: if file["path"].endswith(".nc"):