diff --git a/q2_moshpit/_utils.py b/q2_moshpit/_utils.py index 0bce8a0a..f93efe17 100644 --- a/q2_moshpit/_utils.py +++ b/q2_moshpit/_utils.py @@ -5,7 +5,6 @@ # # The full license is in the file LICENSE, distributed with this software. # ---------------------------------------------------------------------------- -from qiime2.core.exceptions import ValidationError import subprocess import hashlib from typing import List @@ -74,24 +73,11 @@ def _process_common_input_params(processing_func, params: dict) -> List[str]: return processed_args -def colorify(string): +def colorify(string: str): return "%s%s%s" % ('\033[1;32m', string, "\033[0m") -def compare_md5_hashes(expected_hash: str, path_to_file: str): - observed_hash = calculate_md5_from_file(path_to_file) - if observed_hash != expected_hash: - raise ValidationError( - "Download error. Data possibly corrupted.\n" - f"{path_to_file} has an unexpected MD5 hash.\n\n" - "Expected hash:\n" - f"{expected_hash}\n\n" - "Observed hash:\n" - f"{observed_hash}" - ) - - -def calculate_md5_from_file(file_path): +def _calculate_md5_from_file(file_path: str) -> str: md5_hash = hashlib.md5() with open(file_path, 'rb') as f: # Read the file in chunks to handle large files diff --git a/q2_moshpit/eggnog/_dbs.py b/q2_moshpit/eggnog/_dbs.py index dec0e953..65e244f7 100644 --- a/q2_moshpit/eggnog/_dbs.py +++ b/q2_moshpit/eggnog/_dbs.py @@ -6,15 +6,17 @@ # The full license is in the file LICENSE, distributed with this software. # ---------------------------------------------------------------------------- import os +import shutil import pandas as pd +from qiime2.core.exceptions import ValidationError from q2_types.feature_data import ProteinSequencesDirectoryFormat -import shutil from q2_types_genomics.reference_db import ( EggnogRefDirFmt, DiamondDatabaseDirFmt, NCBITaxonomyDirFmt, EggnogProteinSequencesDirFmt ) from .._utils import ( - run_command, _process_common_input_params, colorify, compare_md5_hashes + run_command, _process_common_input_params, colorify, + _calculate_md5_from_file ) from ._utils import _parse_build_diamond_db_params @@ -241,29 +243,30 @@ def _validate_taxon_id(eggnog_proteins, taxon): def fetch_ncbi_taxonomy() -> NCBITaxonomyDirFmt: """ - Script fetches 3 files from the NCBI server and puts them into the folder of - a NCBITaxonomyDirFmt object. + Script fetches 3 files from the NCBI server and puts them into the folder + of a NCBITaxonomyDirFmt object. """ - # Initialize output object and paths ncbi_data = NCBITaxonomyDirFmt() zip_path = os.path.join(str(ncbi_data), "taxdmp.zip") proteins_path = os.path.join(str(ncbi_data), "prot.accession2taxid.gz") - # Download zip file + MD5 file + # Download dump zip file + MD5 file print(colorify("Downloading *.dmp files...")) - for ext in ["", ".md5"]: - # Download MD5 - run_command( - cmd=[ - "wget", "-O", f"{zip_path}{ext}", - f"ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdmp.zip{ext}" - ] - ) + run_command( + cmd=[ + "wget", "-O", f"{zip_path}", + "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdmp.zip" + ] + ) + run_command( + cmd=[ + "wget", "-O", f"{zip_path}.md5", + "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdmp.zip.md5" + ] + ) - # Collect and compare md5 hashes _collect_and_compare_md5(f"{zip_path}.md5", zip_path) - # Unzip run_command( cmd=[ "unzip", "-j", zip_path, "names.dmp", "nodes.dmp", @@ -271,24 +274,27 @@ def fetch_ncbi_taxonomy() -> NCBITaxonomyDirFmt: ] ) - # Remove zip file - run_command(cmd=["rm", zip_path]) + os.remove(zip_path) # Download proteins + MD5 file print(colorify("Downloading proteins file (~8 GB)...")) - for ext in ["", ".md5"]: - run_command( - cmd=[ - "wget", "-O", f"{proteins_path}{ext}", - "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/accession2taxid/" - f"prot.accession2taxid.gz{ext}" - ] - ) + run_command( + cmd=[ + "wget", "-O", f"{proteins_path}", + "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/accession2taxid/" + "prot.accession2taxid.gz" + ] + ) + run_command( + cmd=[ + "wget", "-O", f"{proteins_path}.md5", + "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/accession2taxid/" + "prot.accession2taxid.gz.md5" + ] + ) - # Collect and compare md5 hashes _collect_and_compare_md5(f"{proteins_path}.md5", proteins_path) - # Return object print(colorify( "Done! Moving data from temporary directory to final location..." )) @@ -296,13 +302,22 @@ def fetch_ncbi_taxonomy() -> NCBITaxonomyDirFmt: def _collect_and_compare_md5(path_to_md5: str, path_to_file: str): + # Read in hash from md5 file with open(path_to_md5, 'r') as f: - # Read the first line - first_line = f.readline().strip() - # Split the line into hash and file name - md5_hash, _ = first_line.split(' ', 1) - # Compare - compare_md5_hashes(md5_hash, path_to_file) + expected_hash = f.readline().strip().split(maxsplit=1)[0] + + # Calculate hash from file + observed_hash = _calculate_md5_from_file(path_to_file) + + if observed_hash != expected_hash: + raise ValidationError( + "Download error. Data possibly corrupted.\n" + f"{path_to_file} has an unexpected MD5 hash.\n\n" + "Expected hash:\n" + f"{expected_hash}\n\n" + "Observed hash:\n" + f"{observed_hash}" + ) # If no exception is raised, remove md5 file - run_command(cmd=["rm", path_to_md5]) + os.remove(path_to_md5) diff --git a/q2_moshpit/eggnog/tests/test_dbs.py b/q2_moshpit/eggnog/tests/test_dbs.py index a32fb702..6529d675 100644 --- a/q2_moshpit/eggnog/tests/test_dbs.py +++ b/q2_moshpit/eggnog/tests/test_dbs.py @@ -154,66 +154,80 @@ def test_fetch_eggnog_fasta(self, subp_run): @patch("q2_moshpit.eggnog._dbs._collect_and_compare_md5") @patch("subprocess.run") - def test_fetch_ncbi_taxonomy(self, subp_run, cc_md5): + @patch("os.remove") + def test_fetch_ncbi_taxonomy(self, mock_os_rm, mock_run, mock_md5): # Call function. Patching will make sure nothing is actually ran ncbi_data = fetch_ncbi_taxonomy() zip_path = os.path.join(str(ncbi_data), "taxdmp.zip") proteins_path = os.path.join(str(ncbi_data), "prot.accession2taxid.gz") # Check that command was called in the expected way - I_call, II_call = [ + expected_calls = [ call( [ - "wget", "-O", f"{zip_path}{ext}", - f"ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdmp.zip{ext}" + "wget", "-O", f"{zip_path}", + "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdmp.zip" ], check=True - ) - for ext in ["", ".md5"] - ] - III_call = call(f"{zip_path}.md5", zip_path) - IV_call = call( - [ - "unzip", "-j", zip_path, "names.dmp", "nodes.dmp", - "-d", str(ncbi_data) - ], - check=True, - ) - V_call = call(["rm", zip_path], check=True) - VI_call, VII_call = [ + ), + call( + [ + "wget", "-O", f"{zip_path}.md5", + "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/taxdmp.zip.md5" + ], + check=True + ), + call( + [ + "unzip", "-j", zip_path, "names.dmp", "nodes.dmp", + "-d", str(ncbi_data) + ], + check=True, + ), call( [ - "wget", "-O", f"{proteins_path}{ext}", + "wget", "-O", f"{proteins_path}", "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/accession2taxid/" - f"prot.accession2taxid.gz{ext}" + "prot.accession2taxid.gz" + ], + check=True + ), + call( + [ + "wget", "-O", f"{proteins_path}.md5", + "ftp://ftp.ncbi.nlm.nih.gov/pub/taxonomy/accession2taxid/" + "prot.accession2taxid.gz.md5" ], check=True ) - for ext in ["", ".md5"] ] - VIII_call = call(f"{proteins_path}.md5", proteins_path) # Check that commands are ran as expected - subp_run.assert_has_calls( - [I_call, II_call, IV_call, V_call, VI_call, VII_call], + mock_os_rm.assert_called_once_with(zip_path) + mock_run.assert_has_calls( + expected_calls, + any_order=False + ) + mock_md5.assert_has_calls( + [ + call(f"{zip_path}.md5", zip_path), + call(f"{proteins_path}.md5", proteins_path), + ], any_order=False ) - cc_md5.assert_has_calls([III_call, VIII_call], any_order=False) - @patch("subprocess.run") - def test_collect_and_compare_md5_valid(self, subp_run): + @patch("os.remove") + def test_collect_and_compare_md5_valid(self, mock_os_rm): path_to_file = self.get_data_path("md5/a.txt") # Should raise no errors _collect_and_compare_md5(f"{path_to_file}.md5", path_to_file) # Check rm is called as expected - subp_run.assert_called_once_with( - ["rm", f"{path_to_file}.md5"], check=True - ) + mock_os_rm.assert_called_once_with(f"{path_to_file}.md5") - @patch("subprocess.run") - def test_collect_and_compare_md5_invalid(self, subp_run): + @patch("os.remove") + def test_collect_and_compare_md5_invalid(self, mock_os_rm): path_to_file = self.get_data_path("md5/b.txt") path_to_wrong_md5 = self.get_data_path("md5/a.txt.md5") @@ -225,7 +239,7 @@ def test_collect_and_compare_md5_invalid(self, subp_run): _collect_and_compare_md5(path_to_wrong_md5, path_to_file) # check that rm is not called - subp_run.assert_not_called() + mock_os_rm.assert_not_called() @patch("q2_moshpit.eggnog._dbs._validate_taxon_id") @patch("subprocess.run") diff --git a/q2_moshpit/tests/test_utils.py b/q2_moshpit/tests/test_utils.py index 98a6858d..780a10c7 100644 --- a/q2_moshpit/tests/test_utils.py +++ b/q2_moshpit/tests/test_utils.py @@ -7,10 +7,9 @@ # ---------------------------------------------------------------------------- import unittest from qiime2.plugin.testing import TestPluginBase -from qiime2.core.exceptions import ValidationError from .._utils import ( - _construct_param, _process_common_input_params, compare_md5_hashes, - calculate_md5_from_file + _construct_param, _process_common_input_params, + _calculate_md5_from_file ) @@ -114,28 +113,14 @@ def test_process_common_inputs_mix_with_falsy_values(self): ] self.assertSetEqual(set(observed), set(expected)) - def test_compare_md5_hashes_pass(self): - path_to_file = self.get_data_path("md5/a.txt") - compare_md5_hashes("a583054a9831a6e7cc56ea5cd9cac40a", path_to_file) - - def test_compare_md5_hashes_fail(self): - path_to_file = self.get_data_path("md5/b.txt") - with self.assertRaisesRegex( - ValidationError, - "has an unexpected MD5 hash" - ): - compare_md5_hashes( - "a583054a9831a6e7cc56ea5cd9cac40a", path_to_file - ) - def test_calculate_md5_from_pass(self): path_to_file = self.get_data_path("md5/a.txt") - observed_hash = calculate_md5_from_file(path_to_file) + observed_hash = _calculate_md5_from_file(path_to_file) self.assertEqual(observed_hash, "a583054a9831a6e7cc56ea5cd9cac40a") def test_calculate_md5_from_fail(self): path_to_file = self.get_data_path("md5/b.txt") - observed_hash = calculate_md5_from_file(path_to_file) + observed_hash = _calculate_md5_from_file(path_to_file) self.assertNotEqual(observed_hash, "a583054a9831a6e7cc56ea5cd9cac40a") diff --git a/setup.py b/setup.py index b7cf194e..1f048bff 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ 'q2_moshpit': [ 'citations.bib', 'tests/data/*', + 'tests/data/md5/*', "assets/busco/*", "assets/busco/js/*", "assets/busco/css/*", @@ -47,6 +48,7 @@ ], 'q2_moshpit.eggnog': [ 'tests/data/*', + 'tests/data/md5/*', 'tests/data/build_eggnog_diamond_db/*', 'tests/data/contig-sequences-1/*', 'tests/data/mag-sequences/*',