diff --git a/pgscatalog_utils/config.py b/pgscatalog_utils/config.py index 60d1666..ab78e06 100644 --- a/pgscatalog_utils/config.py +++ b/pgscatalog_utils/config.py @@ -20,6 +20,16 @@ logger = logging.getLogger(__name__) +def headers() -> dict[str, str]: + if PGSC_CALC_VERSION is None: + raise Exception("Missing User-Agent when querying PGS Catalog") + else: + logger.info(f"User-Agent header: {PGSC_CALC_VERSION}") + + header = {"User-Agent": PGSC_CALC_VERSION} + return header + + def setup_tmpdir(outdir, combine=False): if combine: work_dir = "work_combine" diff --git a/pgscatalog_utils/download/Catalog.py b/pgscatalog_utils/download/Catalog.py index b09d00a..3397578 100644 --- a/pgscatalog_utils/download/Catalog.py +++ b/pgscatalog_utils/download/Catalog.py @@ -6,10 +6,10 @@ import requests -from pgscatalog_utils import __version__ as pgscatalog_utils_version from pgscatalog_utils import config from pgscatalog_utils.download.CatalogCategory import CatalogCategory from pgscatalog_utils.download.ScoringFile import ScoringFile +from pgscatalog_utils.download.download_file import get_with_user_agent logger = logging.getLogger(__name__) @@ -69,8 +69,8 @@ def get_download_urls(self) -> dict[str: ScoringFile]: case CatalogCategory.TRAIT | CatalogCategory.PUBLICATION: # publications and traits have to query Catalog API again to grab score data results: list[CatalogResult] = CatalogQuery(CatalogCategory.SCORE, - accession=list(self.pgs_ids), - pgsc_calc_version=config.PGSC_CALC_VERSION).get() + accession=list( + self.pgs_ids)).get() for result in results: for pgs in result.response.get("results"): urls[pgs["id"]] = ScoringFile.from_result(pgs) @@ -84,12 +84,9 @@ class CatalogQuery: """ category: CatalogCategory accession: typing.Union[str, list[str]] - pgsc_calc_version: typing.Union[str, None] include_children: bool = False _rest_url_root: str = "https://www.pgscatalog.org/rest" _max_retries: int = 5 - _version: str = pgscatalog_utils_version - _user_agent: dict[str: str] = field(init=False) def _resolve_query_url(self) -> typing.Union[str, list[str]]: child_flag: int = int(self.include_children) @@ -109,16 +106,8 @@ def _resolve_query_url(self) -> typing.Union[str, list[str]]: case CatalogCategory.PUBLICATION, str(): return f"{self._rest_url_root}/publication/{self.accession}" case _: - raise Exception(f"Invalid CatalogCategory and accession type: {self.category}, type({self.accession})") - - def __post_init__(self): - ua: str - if self.pgsc_calc_version: - ua = pgscatalog_utils_version - else: - ua = f"pgscatalog_utils/{self._version}" - - self._user_agent = {"User-Agent": ua} + raise Exception( + f"Invalid CatalogCategory and accession type: {self.category}, type({self.accession})") def _query_api(self, url: str): wait: int = 10 @@ -128,7 +117,7 @@ def _query_api(self, url: str): while retry < self._max_retries: try: logger.info(f"Querying {url}") - r: requests.models.Response = requests.get(url, headers=self._user_agent) + r: requests.models.Response = get_with_user_agent(url) r.raise_for_status() results_json = r.json() break diff --git a/pgscatalog_utils/download/ScoringFileChecksum.py b/pgscatalog_utils/download/ScoringFileChecksum.py index 1842076..06e29d6 100644 --- a/pgscatalog_utils/download/ScoringFileChecksum.py +++ b/pgscatalog_utils/download/ScoringFileChecksum.py @@ -15,7 +15,7 @@ def _generate_md5_checksum(filename: str, blocksize=4096) -> typing.Union[str, N """ Returns MD5 checksum for the given file. """ md5 = hashlib.md5() try: - file = open(filename, 'rb') + file = open(config.OUTDIR.joinpath(filename), 'rb') with file: for block in iter(lambda: file.read(blocksize), b""): md5.update(block) diff --git a/pgscatalog_utils/download/download_file.py b/pgscatalog_utils/download/download_file.py index 720b754..a7a638e 100644 --- a/pgscatalog_utils/download/download_file.py +++ b/pgscatalog_utils/download/download_file.py @@ -1,5 +1,4 @@ import logging -import os import pathlib import time import urllib.parse @@ -13,6 +12,10 @@ logger = logging.getLogger(__name__) +def get_with_user_agent(url: str) -> requests.Response: + return requests.get(url, headers=config.headers()) + + def download_file(url: str, local_path: str, overwrite: bool, ftp_fallback: bool) -> None: if config.OUTDIR.joinpath(local_path).exists(): if not overwrite: @@ -25,7 +28,7 @@ def download_file(url: str, local_path: str, overwrite: bool, ftp_fallback: bool attempt: int = 0 while attempt < config.MAX_RETRIES: - response: requests.Response = requests.get(url) + response: requests.Response = get_with_user_agent(url) match response.status_code: case 200: with open(config.OUTDIR.joinpath(local_path), "wb") as f: @@ -69,3 +72,4 @@ def _ftp_fallback_download(url: str, local_path: str) -> None: else: logger.critical(f"Download failed: {e}") raise Exception + diff --git a/pgscatalog_utils/download/download_scorefile.py b/pgscatalog_utils/download/download_scorefile.py index 1cc293f..4139fd3 100644 --- a/pgscatalog_utils/download/download_scorefile.py +++ b/pgscatalog_utils/download/download_scorefile.py @@ -1,17 +1,16 @@ import argparse import logging -import os import pathlib import textwrap import typing +from pgscatalog_utils import __version__ as version from pgscatalog_utils import config -from pgscatalog_utils.download.CatalogCategory import CatalogCategory from pgscatalog_utils.download.Catalog import CatalogQuery, CatalogResult +from pgscatalog_utils.download.CatalogCategory import CatalogCategory from pgscatalog_utils.download.GenomeBuild import GenomeBuild from pgscatalog_utils.download.ScoringFileDownloader import ScoringFileDownloader - logger = logging.getLogger(__name__) @@ -44,7 +43,11 @@ def download_scorefile() -> None: if args.pgsc_calc: config.PGSC_CALC_VERSION = args.pgsc_calc - logger.info(f"Setting user agent to {config.PGSC_CALC_VERSION} for PGS Catalog API queries") + logger.info( + f"Setting user agent to {config.PGSC_CALC_VERSION} for PGS Catalog API queries") + else: + config.PGSC_CALC_VERSION = f"pgscatalog_utils/{version}" + logger.warning(f"No user agent set, defaulting to {config.PGSC_CALC_VERSION}") config.OUTDIR = pathlib.Path(args.outdir).resolve() logger.info(f"Download directory: {config.OUTDIR}") @@ -60,19 +63,19 @@ def download_scorefile() -> None: else: logger.debug("--trait set, querying traits") for term in args.efo: - results.append(CatalogQuery(CatalogCategory.TRAIT, term, include_children=inc_child, - pgsc_calc_version=config.PGSC_CALC_VERSION).get()) + results.append(CatalogQuery(CatalogCategory.TRAIT, term, + include_children=inc_child).get()) if args.pgp: logger.debug("--pgp set, querying publications") for term in args.pgp: - results.append(CatalogQuery(CatalogCategory.PUBLICATION, term, pgsc_calc_version=config.PGSC_CALC_VERSION).get()) + results.append(CatalogQuery(CatalogCategory.PUBLICATION, term).get()) if args.pgs: logger.debug("--id set, querying scores") results.append( - CatalogQuery(CatalogCategory.SCORE, args.pgs, - pgsc_calc_version=config.PGSC_CALC_VERSION).get()) # pgs_lst: a list containing up to three flat lists + CatalogQuery(CatalogCategory.SCORE, + args.pgs).get()) # pgs_lst: a list containing up to three flat lists flat_results = [element for sublist in results for element in sublist] diff --git a/tests/test_combine.py b/tests/test_combine.py index 306564f..db92cc9 100644 --- a/tests/test_combine.py +++ b/tests/test_combine.py @@ -33,7 +33,7 @@ def test_fail_combine(scorefiles, tmp_path_factory): @pytest.fixture def _n_variants(pgs_accessions): - result = CatalogQuery(CatalogCategory.SCORE, accession=pgs_accessions, pgsc_calc_version=None).get()[0] + result = CatalogQuery(CatalogCategory.SCORE, accession=pgs_accessions).get()[0] json = result.response n: list[int] = jq.compile("[.results][][].variants_number").input(json).all() return sum(n) diff --git a/tests/test_download.py b/tests/test_download.py index 27e5a6a..3ea12a5 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,4 +1,5 @@ import gzip +import logging import os from unittest.mock import patch @@ -7,6 +8,22 @@ from pgscatalog_utils.download.download_scorefile import download_scorefile +def test_checksum_validation(tmp_path, caplog): + out_dir = str(tmp_path.resolve()) + pgs_id = 'PGS000001' + args: list[str] = ['download_scorefiles', '-i', pgs_id, '-b', 'GRCh38', '-o', + out_dir, '-v'] + + with patch('sys.argv', args): + caplog.set_level(logging.INFO) + # Test download + download_scorefile() + hm_score_filename = f'{pgs_id}_hmPOS_GRCh38.txt.gz' + assert hm_score_filename in os.listdir(out_dir) + # make sure validation passed + assert "Checksum matches" in [x.message for x in caplog.records] + + def test_download_scorefile_author(tmp_path): out_dir = str(tmp_path.resolve()) pgs_id = 'PGS000001' @@ -92,13 +109,11 @@ def test_download_trait(tmp_path): def test_query_publication(): # publications are relatively static - query: list[CatalogResult] = CatalogQuery(CatalogCategory.PUBLICATION, accession="PGP000001", - pgsc_calc_version=None).get() + query: list[CatalogResult] = CatalogQuery(CatalogCategory.PUBLICATION, accession="PGP000001").get() assert not query[0].pgs_ids.difference({'PGS000001', 'PGS000002', 'PGS000003'}) def test_query_trait(): # new scores may be added to traits in the future - query: list[CatalogResult] = CatalogQuery(CatalogCategory.TRAIT, accession="EFO_0004329", - pgsc_calc_version=None).get() + query: list[CatalogResult] = CatalogQuery(CatalogCategory.TRAIT, accession="EFO_0004329").get() assert not {'PGS001901', 'PGS002115'}.difference(query[0].pgs_ids)