Skip to content

Commit

Permalink
MAINT: add progress bar and informative prints to fetch-card-db act…
Browse files Browse the repository at this point in the history
…ion (#68)
  • Loading branch information
VinzentRisch authored Apr 30, 2024
1 parent 1a3f7cf commit 0ff2501
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 41 deletions.
1 change: 1 addition & 0 deletions ci/recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ requirements:
- q2templates {{ qiime2_epoch }}.*
- q2cli {{ qiime2_epoch }}.*
- rgi
- tqdm

test:
requires:
Expand Down
79 changes: 60 additions & 19 deletions q2_amr/card/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,54 @@
import tempfile

import requests
from tqdm import tqdm

from q2_amr.card.utils import run_command
from q2_amr.card.utils import colorify, run_command
from q2_amr.types._format import (
CARDDatabaseDirectoryFormat,
CARDKmerDatabaseDirectoryFormat,
)


def fetch_card_db() -> (CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFormat):
# Fetch CARD and WildCARD data from CARD website
try:
response_card = requests.get(
"https://card.mcmaster.ca/latest/data", stream=True
)
response_wildcard = requests.get(
"https://card.mcmaster.ca/latest/variants", stream=True
)
except requests.ConnectionError as e:
raise requests.ConnectionError("Network connectivity problems.") from e

# Create temporary directory for WildCARD data
with tempfile.TemporaryDirectory() as tmp_dir:
try:
card_tar_path = os.path.join(tmp_dir, "card_tar")
wildcard_tar_path = os.path.join(tmp_dir, "wildcard_tar")

# Download CARD and WildCARD tar database archives with progressbars
download_with_progress_bar(
url="https://card.mcmaster.ca/latest/data",
description="Downloading CARD database",
tar_path=card_tar_path,
)

download_with_progress_bar(
url="https://card.mcmaster.ca/latest/variants",
description="Downloading WildCARD database",
tar_path=wildcard_tar_path,
)

except requests.ConnectionError as e:
raise requests.ConnectionError(
"Unable to connect to the CARD server. Please try again later."
) from e

print(colorify("Extracting database files..."), flush=True)

# Create directories to store zipped and unzipped database files
os.mkdir(os.path.join(tmp_dir, "card"))
os.mkdir(os.path.join(tmp_dir, "wildcard_zip"))
os.mkdir(os.path.join(tmp_dir, "wildcard"))

# Extract tar.bz2 archives and store files in dirs "card" and "wildcard_zip"
try:
with tarfile.open(
fileobj=response_card.raw, mode="r|bz2"
) as c_tar, tarfile.open(
fileobj=response_wildcard.raw, mode="r|bz2"
) as wc_tar:
with tarfile.open(card_tar_path, mode="r:bz2") as c_tar:
c_tar.extractall(path=os.path.join(tmp_dir, "card"))

with tarfile.open(wildcard_tar_path, mode="r|bz2") as wc_tar:
wc_tar.extractall(path=os.path.join(tmp_dir, "wildcard_zip"))

except tarfile.ReadError as a:
raise tarfile.ReadError("Tarfile is invalid.") from a

Expand All @@ -63,12 +78,16 @@ def fetch_card_db() -> (CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFo
) as f_out:
f_out.write(f_in.read())

print(colorify("Preprocessing database files..."), flush=True)

# Preprocess data for CARD and WildCARD
# This creates additional fasta files in the temp directory
preprocess(dir=tmp_dir, operation="card")
preprocess(dir=tmp_dir, operation="wildcard")

# Create CARD and Kmer database objects
print(colorify("Creating database artifacts..."), flush=True)

# Create CARD and Kmer database artifacts
card_db = CARDDatabaseDirectoryFormat()
kmer_db = CARDKmerDatabaseDirectoryFormat()

Expand Down Expand Up @@ -115,6 +134,28 @@ def fetch_card_db() -> (CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFo
return card_db, kmer_db


def download_with_progress_bar(url, description, tar_path):
    """Stream the file behind ``url`` into ``tar_path`` and show a progress bar.

    Parameters
    ----------
    url : str
        Address requested with ``requests.get`` in streaming mode.
    description : str
        Label handed to tqdm as the description of the progress bar.
    tar_path : str
        Path of the file the downloaded bytes are written to.
    """
    response = requests.get(url=url, stream=True)

    try:
        # The content length is used as progress bar total. It falls back to 0
        # if the server does not send the header.
        tot_size = int(response.headers.get("content-length", 0))

        progress_bar = tqdm(
            total=tot_size,
            unit="B",
            unit_scale=True,
            desc=description,
        )

        try:
            with open(tar_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    # Empty chunks carry no data, so there is nothing to write
                    if not chunk:
                        continue
                    file.write(chunk)
                    progress_bar.update(len(chunk))
        finally:
            # Close the bar in every case, also if the total size is unknown or
            # the download is interrupted by an exception
            progress_bar.close()
    finally:
        # Release the connection that was kept open for streaming
        response.close()


def preprocess(dir, operation):
if operation == "card":
# Run RGI command for CARD data
Expand Down
65 changes: 44 additions & 21 deletions q2_amr/card/tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import shutil
import subprocess
import tarfile
from unittest.mock import MagicMock, call, patch
from unittest.mock import MagicMock, patch

import requests
from qiime2.plugin.testing import TestPluginBase

from q2_amr.card.database import fetch_card_db, preprocess
from q2_amr.card.database import download_with_progress_bar, fetch_card_db, preprocess
from q2_amr.types import CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFormat


Expand All @@ -32,19 +32,16 @@ def mock_preprocess(self, dir, operation):

def test_fetch_card_db(self):
# Open dummy archives for CARD and WildCARD download
f_card = open(self.get_data_path("card.tar.bz2"), "rb")
f_wildcard = open(self.get_data_path("wildcard_data.tar.bz2"), "rb")

# Create MagicMock objects to simulate responses from requests.get
mock_response_card = MagicMock(raw=f_card)
mock_response_wildcard = MagicMock(raw=f_wildcard)
card_tar = self.get_data_path("card.tar.bz2")
wildcard_tar = self.get_data_path("wildcard_data.tar.bz2")

# Patch the download helper, preprocess and tarfile.open
with patch("requests.get") as mock_requests, patch(
with patch("q2_amr.card.database.download_with_progress_bar"), patch(
"q2_amr.card.database.preprocess", side_effect=self.mock_preprocess
), patch(
"tarfile.open",
side_effect=[tarfile.open(card_tar), tarfile.open(wildcard_tar)],
):
# Assign MagicMock objects as side effects and run the function
mock_requests.side_effect = [mock_response_card, mock_response_wildcard]
obs = fetch_card_db()

# Lists of filenames contained in CARD and Kmer database objects
Expand All @@ -60,6 +57,7 @@ def test_fetch_card_db(self):
"card_database_v3.2.7_all.fasta",
"card.json",
]

files_kmer_db = ["all_amr_61mers.txt", "61_kmer_db.json"]

# Assert if all files are in the correct database object
Expand All @@ -73,26 +71,21 @@ def test_fetch_card_db(self):
self.assertIsInstance(obs[0], CARDDatabaseDirectoryFormat)
self.assertIsInstance(obs[1], CARDKmerDatabaseDirectoryFormat)

# Assert if requests.get gets called with the correct URLs
expected_calls = [
call("https://card.mcmaster.ca/latest/data", stream=True),
call("https://card.mcmaster.ca/latest/variants", stream=True),
]
mock_requests.assert_has_calls(expected_calls)

def test_connection_error(self):
# Simulate a ConnectionError during the database download
with patch(
"requests.get", side_effect=requests.ConnectionError
"q2_amr.card.database.download_with_progress_bar",
side_effect=requests.ConnectionError,
), self.assertRaisesRegex(
requests.ConnectionError, "Network connectivity problems."
requests.ConnectionError,
"Unable to connect to the CARD server. " "Please try again later.",
):
fetch_card_db()

def test_tarfile_read_error(self):
# Simulate a tarfile.ReadError during tarfile.open
with patch("tarfile.open", side_effect=tarfile.ReadError), patch(
"requests.get"
"q2_amr.card.database.download_with_progress_bar"
), self.assertRaisesRegex(tarfile.ReadError, "Tarfile is invalid."):
fetch_card_db()

Expand Down Expand Up @@ -138,3 +131,33 @@ def test_preprocess_wildcard(self):
"path_tmp",
verbose=True,
)

def test_download_with_progressbar(self):
    # Check that download_with_progress_bar requests the URL as a stream, writes
    # the received chunk to the target file and reports the progress to tqdm
    url = "http://example.com"
    progressbar_desc = "Downloading"
    tar_path = "/path/to/downloaded/file.tar"

    with patch("requests.get") as mock_get, patch(
        "q2_amr.card.database.tqdm"
    ) as mock_tqdm, patch("builtins.open") as mock_open:
        # Mock response object that delivers a single chunk
        response_mock = MagicMock()
        response_mock.headers = {"content-length": "1024"}
        response_mock.iter_content.return_value = [b"data"]

        # Patch the requests.get to return our response_mock
        mock_get.return_value = response_mock

        # Mock the file handle returned by open to avoid creating files and to
        # be able to inspect what gets written
        file_mock = MagicMock()
        mock_open.return_value.__enter__.return_value = file_mock

        # Call the function
        download_with_progress_bar(url, progressbar_desc, tar_path)

        # Assert that request, progress bar and file are set up correctly
        mock_get.assert_called_once_with(url=url, stream=True)
        mock_tqdm.assert_called_once_with(
            total=1024, unit="B", unit_scale=True, desc=progressbar_desc
        )
        mock_open.assert_called_once_with(tar_path, "wb")
        response_mock.iter_content.assert_called_once_with(chunk_size=8192)

        # Assert that the chunk ends up in the file, advances the progress bar
        # by its size and that the bar gets closed afterwards
        file_mock.write.assert_called_once_with(b"data")
        mock_tqdm.return_value.update.assert_called_once_with(len(b"data"))
        mock_tqdm.return_value.close.assert_called_once_with()
9 changes: 8 additions & 1 deletion q2_amr/card/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
from qiime2.plugin.testing import TestPluginBase

from q2_amr.card.utils import create_count_table, load_card_db, read_in_txt
from q2_amr.card.utils import colorify, create_count_table, load_card_db, read_in_txt
from q2_amr.types import CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFormat


Expand Down Expand Up @@ -191,3 +191,10 @@ def test_create_count_table(self):
def test_create_count_table_value_error(self):
    # An empty list has to make create_count_table raise a ValueError
    with self.assertRaises(ValueError):
        create_count_table([])

def test_colorify(self):
    # colorify has to put the color code in front of the text and the reset
    # code behind it
    plain_text = "Hello, world!"
    self.assertEqual(colorify(plain_text), "\033[1;32mHello, world!\033[0m")
4 changes: 4 additions & 0 deletions q2_amr/card/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,7 @@ def create_count_table(df_list: list) -> pd.DataFrame:
df.columns.name = None
df.index.name = "sample_id"
return df


def colorify(string: str):
    """Return ``string`` wrapped in the ANSI codes for bold green and reset."""
    bold_green = "\033[1;32m"
    reset = "\033[0m"
    return f"{bold_green}{string}{reset}"

0 comments on commit 0ff2501

Please sign in to comment.