Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Check registry file hash before downloading #93

Merged
merged 15 commits into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 60 additions & 10 deletions src/lephare/data_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,34 +53,84 @@ def filter_files_by_prefix(file_path, target_prefixes):
return matching_lines


def _check_registry_is_latest_version(remote_registry_url, local_registry_file):
"""Checks whether the local registry file is the latest version compared to a remote registry.

Parameters
----------
remote_registry_url : str
The URL to the remote registry file, used to construct the URL to fetch the remote hash.
local_registry_file : str
The path to the local registry file whose up-to-date status is to be checked.

Returns
-------
bool
Returns True if the local registry file is up to date, otherwise False.

Notes
-----
We make the assumption that the hash file for the registry will be stored in
the same directory as the registry file, with the same name (sans extension)
plus "_hash.sha256".

Raises
------
Exception
If there is any problem fetching the registry hash file, including network issues,
server errors, or other HTTP errors.
"""
local_registry_hash = pooch.file_hash(local_registry_file, alg="sha256")
remote_hash_url = os.path.splitext(remote_registry_url)[0] + "_hash.sha256"

remote_hash_response = requests.get(remote_hash_url, timeout=60)
remote_hash_response.raise_for_status() # Raise exceptions for non-200 status codes

if remote_hash_response.text.strip() == local_registry_hash:
print(f"Local registry file is up to date: {local_registry_file}")
return True
OliviaLynn marked this conversation as resolved.
Show resolved Hide resolved
else:
print(f"Local registry file is not up to date: {local_registry_file}")
return False


def download_registry_from_github(url="", outfile=""):
"""Fetch the contents of a file from a GitHub repository.

Parameters
----------
url : str
The URL of the registry file. Defaults to a "data-registry.txt" file at
The URL of the registry file. Defaults to a "data_registry.txt" file at
DEFAULT_BASE_DATA_URL.
outfile : str
The path where the file will be saved. Defaults to DEFAULT_REGISTRY_FILE.

Raises
------
Exception
If the file cannot be fetched from the URL.
If there is any problem fetching the registry hash file or full registry file,
including network issues, server errors, or other HTTP errors.
"""
remote_registry_name = "data_registry.txt"

# Assign defaults if keywords left blank
if url == "":
url = urljoin(DEFAULT_BASE_DATA_URL, "data-registry.txt")
url = urljoin(DEFAULT_BASE_DATA_URL, remote_registry_name)
if outfile == "":
outfile = DEFAULT_REGISTRY_FILE

response = requests.get(url, timeout=60)
if response.status_code == 200:
with open(outfile, "w", encoding="utf-8") as file:
file.write(response.text)
print(f"Registry file downloaded and saved as {outfile}.")
else:
raise requests.exceptions.HTTPError(f"Failed to fetch file: {response.status_code}")
# If local registry hash matches remote hash, our registry is already up-to-date:
if os.path.isfile(outfile) and _check_registry_is_latest_version(url, outfile):
return

# Download the registry file
response = requests.get(url, timeout=120)
response.raise_for_status() # Raise exceptions for non-200 status codes

with open(outfile, "w", encoding="utf-8") as file:
file.write(response.text)

print(f"Registry file downloaded and saved as {outfile}.")


def read_list_file(list_file, prefix=""):
Expand Down
22 changes: 0 additions & 22 deletions tests/lephare/test_data_retrieval.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import mock_open, patch

import pytest
import requests
from lephare.data_retrieval import (
DEFAULT_BASE_DATA_URL,
DEFAULT_LOCAL_DATA_PATH,
Expand All @@ -13,7 +11,6 @@
_create_directories_from_files,
download_all_files,
download_file,
download_registry_from_github,
filter_files_by_prefix,
make_default_retriever,
make_retriever,
Expand All @@ -28,25 +25,6 @@ def test_filter_file_by_prefix(test_data_dir):
assert filter_files_by_prefix(file_path, target_prefixes) == expected_lines


@patch("requests.get")
def test_download_registry_from_github_success(mock_get):
mock_get.return_value.status_code = 200
mock_get.return_value.text = "file1\nfile2\nfile3"

with tempfile.TemporaryDirectory() as tmpdir:
download_registry_from_github(outfile=os.path.join(tmpdir, "registry.txt"))

with open(os.path.join(tmpdir, "registry.txt"), "r") as file:
assert file.read() == "file1\nfile2\nfile3"


@patch("requests.get")
def test_download_registry_from_github_failure(mock_get):
mock_get.return_value.status_code = 404
with pytest.raises(requests.exceptions.HTTPError):
download_registry_from_github()


def test_read_list_file(test_data_dir):
file_path = os.path.join(test_data_dir, "test_file_names.list")
expected_files = ["prefix1_file1", "prefix2_file2"]
Expand Down
188 changes: 188 additions & 0 deletions tests/lephare/test_data_retrieval_registry.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, I like the way you've documented the logic and the tests that cover it.

Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""There's a lot of branching logic in download_registry_from_github.

I did my best to keep it simple, but it's still a little tricky.

Here's how it works:

Call download_registry_from_github
1. Local registry file does not exist → Download remote registry file
1. Successfully downloaded registry file → Exit (now have local registry)
2. Fail to download registry file → raise Exception
2. Local registry file exists
→ Need to check if local registry is up to date
→ Call _check_registry_is_latest_version
1. True → Exit (confirmed local registry is up to date)
2. False → Download updated version
1. Successfully downloaded registry file → Exit (local registry updated)
2. Fail to download registry file → raise Exception
3. Failed to download hash file → raise Exception
"""

import os
import tempfile
from unittest.mock import Mock, patch

import pytest
import requests
from lephare.data_retrieval import (
download_registry_from_github,
)


def test_download_registry_success():
# 1. Local registry does not exist (so no mocking needed here)
# 1. Successfully downloaded registry file

# Mock remote registry file
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "file1\nfile2\nfile3"
with patch("requests.get", return_value=mock_response) as mock_get_remote_registry: # noqa: F841
with tempfile.TemporaryDirectory() as tmp_dir:
registry_outfile = os.path.join(tmp_dir, "registry.txt")
download_registry_from_github(outfile=registry_outfile)
# Check that we can open it (and it contains expected content)
with open(registry_outfile, "r") as file:
assert file.read() == "file1\nfile2\nfile3"


def test_download_registry_failure():
# 1. Local registry does not exist (no mocking needed)
# 2. Fail to download registry file

# Mock failed registry file download
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
"404 Client Error: Not Found for url"
)
with patch("requests.get", return_value=mock_response) as mock_get_remote_registry: # noqa: F841
with pytest.raises(requests.exceptions.HTTPError):
download_registry_from_github()


def test_update_registry_hash_matches():
# 2. Local registry exists
# 1. _check_registry_is_latest_version == True

# Mock the local registry file existing
with patch("os.path.isfile", return_value=True) as mock_local_registry_existing: # noqa: F841
# Mock local registry having a certain pooch hash
with patch("pooch.file_hash", return_value="registryhash123") as mock_local_registry_hash: # noqa: F841
# Mock getting the remote registry hash file
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "registryhash123"
with patch("requests.get", return_value=mock_response) as mock_get_remote_hash_file:
# Call the function
with tempfile.TemporaryDirectory() as tmp_dir:
registry_outfile = os.path.join(tmp_dir, "registry.txt")
download_registry_from_github(
url="http://example.com/data_registry.txt", outfile=registry_outfile
)
# Assert that we only make 1 request for the hash file, and that we used the right url:
assert mock_get_remote_hash_file.call_count == 1
assert (
mock_get_remote_hash_file.call_args[0][0]
== "http://example.com/data_registry_hash.sha256"
)


def test_update_registry_hash_mismatches():
# 2. Local registry exists
# 2. _check_registry_is_latest_version == False
# 1. Successfully downloaded registry file

# Mock the local registry file existing
with patch("os.path.isfile", return_value=True) as mock_local_registry_existing: # noqa: F841
# Mock local registry having a certain pooch hash
with patch("pooch.file_hash", return_value="registryhash123") as mock_local_registry_hash: # noqa: F841
# Mock getting the remote hash/registry files
mock_hash_response = Mock()
mock_hash_response.status_code = 200
mock_hash_response.text = "hash_doesn't_match123"

mock_registry_response = Mock()
mock_registry_response.status_code = 200
mock_registry_response.text = "file1\nfile2\nfile3"

def which_mock_get(*args, **kwargs):
url = args[0]
if "hash.sha256" in url:
return mock_hash_response
else:
return mock_registry_response

with patch("requests.get", side_effect=which_mock_get) as mock_remote_files_get:
# Call the function
with tempfile.TemporaryDirectory() as tmp_dir:
registry_outfile = os.path.join(tmp_dir, "registry.txt")
download_registry_from_github(
url="http://example.com/data_registry.txt", outfile=registry_outfile
)
# Checks:
# One call to download the hash file, one call to download the full registry file:
assert mock_remote_files_get.call_count == 2
# The following set of [0][0][0] and such is because the call args list is
# [call('http://example.com/data_registry_hash.sha256', timeout=60),
# call('http://example.com/data_registry.txt', timeout=60)]
# and we're only interested in checking the urls:
assert (
mock_remote_files_get.call_args_list[0][0][0]
== "http://example.com/data_registry_hash.sha256"
)
assert (
mock_remote_files_get.call_args_list[1][0][0]
== "http://example.com/data_registry.txt"
)


def test_update_registry_hash_mismatches_and_download_fails():
# 2. Local registry exists
# 2. _check_registry_is_latest_version == False
# 2. Fail to download registry file

# Mock the local registry file existing
with patch("os.path.isfile", return_value=True) as mock_local_registry_existing: # noqa: F841
# Mock local registry having a certain pooch hash
with patch("pooch.file_hash", return_value="registryhash123") as mock_local_registry_hash: # noqa: F841
# Mock getting the remote hash/registry files
mock_hash_response = Mock()
mock_hash_response.status_code = 200
mock_hash_response.text = "hash_doesn't_match123"

mock_registry_response = Mock()
mock_registry_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
"404 Client Error: Not Found for url"
)

def which_mock_get(*args, **kwargs):
url = args[0]
if "hash.sha256" in url:
return mock_hash_response
else:
return mock_registry_response

with patch("requests.get", side_effect=which_mock_get) as mock_remote_files_get: # noqa: F841
# Check that we raise HTTPError as expected
with pytest.raises(requests.exceptions.HTTPError):
download_registry_from_github()


def test_update_registry_hash_download_fails():
# 2. Local registry exists
# 3. Fail to download registry hash file

# Mock the local registry file existing
with patch("os.path.isfile", return_value=True) as mock_local_registry_existing: # noqa: F841
# Mock local registry having a certain pooch hash
with patch("pooch.file_hash", return_value="registryhash123") as mock_local_registry_hash: # noqa: F841
# Mock getting the remote hash/registry files
mock_hash_response = Mock()
mock_hash_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
"404 Client Error: Not Found for url"
)

with patch("requests.get", return_value=mock_hash_response) as mock_get_remote_hash_file: # noqa: F841
# Check that we get the expected exception
with pytest.raises(requests.exceptions.HTTPError):
download_registry_from_github()