diff --git a/.coveragerc b/.coveragerc index 2998fde..4776ee2 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,4 +1,5 @@ [run] +source=q2_amr branch = True omit = */tests* diff --git a/.gitignore b/.gitignore index 744fd2a..b85dd07 100644 --- a/.gitignore +++ b/.gitignore @@ -42,6 +42,7 @@ htmlcov/ .tox/ .nox/ .coverage +.coveragerc .coverage.* .cache nosetests.xml diff --git a/Makefile b/Makefile index a862ca3..6bebe23 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,8 @@ test: all py.test test-cov: all - py.test --cov=q2_amr + coverage run -m pytest + coverage xml install: all $(PYTHON) setup.py install diff --git a/ci/recipe/meta.yaml b/ci/recipe/meta.yaml index 01d1b7e..0fd7f77 100644 --- a/ci/recipe/meta.yaml +++ b/ci/recipe/meta.yaml @@ -19,6 +19,12 @@ requirements: run: - python {{ python }} - qiime2 {{ qiime2_epoch }}.* + - q2-types-genomics {{ qiime2_epoch }}.* + - q2-types {{ qiime2_epoch }}.* + - q2templates {{ qiime2_epoch }}.* + - q2cli {{ qiime2_epoch }}.* + - rgi + - altair test: requires: @@ -28,7 +34,7 @@ test: - q2_amr - qiime2.plugins.amr commands: - - pytest --cov q2_amr --pyargs q2_amr + - pytest --cov q2_amr --cov-report xml:coverage.xml --pyargs q2_amr about: home: https://github.com/bokulich-lab/q2-amr diff --git a/q2_amr/card/database.py b/q2_amr/card/database.py index c232d38..733b59d 100644 --- a/q2_amr/card/database.py +++ b/q2_amr/card/database.py @@ -1,28 +1,142 @@ +import glob +import gzip import os import shutil +import subprocess import tarfile import tempfile import requests -from q2_amr.types import CARDDatabaseDirectoryFormat +from q2_amr.card.utils import run_command +from q2_amr.types._format import ( + CARDDatabaseDirectoryFormat, + CARDKmerDatabaseDirectoryFormat, +) -CARD_URL = "https://card.mcmaster.ca/latest/data" - -def fetch_card_db() -> CARDDatabaseDirectoryFormat: +def fetch_card_db() -> (CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFormat): + # Fetch CARD and WildCARD data from CARD website try: - response = requests.get(CARD_URL, stream=True) + response_card = requests.get( + "https://card.mcmaster.ca/latest/data", stream=True + ) + response_wildcard = requests.get( + "https://card.mcmaster.ca/latest/variants", stream=True + ) except requests.ConnectionError as e: raise requests.ConnectionError("Network connectivity problems.") from e + + # Create temporary directory for WildCARD data with tempfile.TemporaryDirectory() as tmp_dir: + os.mkdir(os.path.join(tmp_dir, "wildcard")) + + # Extract tar.bz2 archives and store files in dirs "card" and "wildcard_zip" try: - with tarfile.open(fileobj=response.raw, mode="r|bz2") as tar: - tar.extractall(path=tmp_dir) + with tarfile.open( + fileobj=response_card.raw, mode="r|bz2" + ) as c_tar, tarfile.open( + fileobj=response_wildcard.raw, mode="r|bz2" + ) as wc_tar: + c_tar.extractall(path=os.path.join(tmp_dir, "card")) + wc_tar.extractall(path=os.path.join(tmp_dir, "wildcard_zip")) except tarfile.ReadError as a: raise tarfile.ReadError("Tarfile is invalid.") from a + + # List of files to be unzipped + files = ( + "index-for-model-sequences.txt.gz", + "nucleotide_fasta_protein_homolog_model_variants.fasta.gz", + "nucleotide_fasta_protein_overexpression_model_variants.fasta.gz", + "nucleotide_fasta_protein_variant_model_variants.fasta.gz", + "nucleotide_fasta_rRNA_gene_variant_model_variants.fasta.gz", + "61_kmer_db.json.gz", + "all_amr_61mers.txt.gz", + ) + + # Unzip gzip files and save them in "wildcard" dir + for file in files: + with gzip.open( + os.path.join(tmp_dir, "wildcard_zip", file), "rb" + ) as f_in, open( + os.path.join(tmp_dir, "wildcard", file[:-3]), "wb" + ) as f_out: + f_out.write(f_in.read()) + + # Preprocess data for CARD and WildCARD + # This creates additional fasta files in the temp directory + preprocess(dir=tmp_dir, operation="card") + preprocess(dir=tmp_dir, operation="wildcard") + + # Create CARD and Kmer database objects card_db = CARDDatabaseDirectoryFormat() - shutil.move( - os.path.join(tmp_dir, "card.json"), os.path.join(str(card_db), "card.json") + kmer_db = CARDKmerDatabaseDirectoryFormat() + + # Find names of CARD database files created by preprocess function + card_db_files = [ + os.path.basename(file) + for file in glob.glob(os.path.join(tmp_dir, "card_database_v*.fasta")) + ] + + # Lists of filenames to be moved to CARD and Kmer database objects + wildcard_to_card_db = [ + "index-for-model-sequences.txt", + "nucleotide_fasta_protein_homolog_model_variants.fasta", + "nucleotide_fasta_protein_overexpression_model_variants.fasta", + "nucleotide_fasta_protein_variant_model_variants.fasta", + "nucleotide_fasta_rRNA_gene_variant_model_variants.fasta", + ] + tmp_to_card_db = [ + "wildcard_database_v0.fasta", + "wildcard_database_v0_all.fasta", + card_db_files[0], + card_db_files[1], + ] + wildcard_to_kmer_db = ["all_amr_61mers.txt", "61_kmer_db.json"] + + # List of source and destination paths for files + src_des_list = [ + (os.path.join(tmp_dir, "card"), str(card_db)), + (os.path.join(tmp_dir, "wildcard"), str(card_db)), + (tmp_dir, str(card_db)), + (os.path.join(tmp_dir, "wildcard"), str(kmer_db)), + ] + + # Move all files from source path to destination path + for file_list, src_des in zip( + [["card.json"], wildcard_to_card_db, tmp_to_card_db, wildcard_to_kmer_db], + src_des_list, + ): + for file in file_list: + shutil.move( + os.path.join(src_des[0], file), os.path.join(src_des[1], file) + ) + + return card_db, kmer_db + + +def preprocess(dir, operation): + if operation == "card": + # Run RGI command for CARD data + cmd = ["rgi", "card_annotation", "-i", "card/card.json"] + elif operation == "wildcard": + # Run RGI command for WildCARD data + cmd = [ + "rgi", + "wildcard_annotation", + "-i", + "wildcard", + "--card_json", + "card/card.json", + "-v", + "0", + ] + + try: + run_command(cmd, dir, verbose=True) + except subprocess.CalledProcessError as e: + raise Exception( + f"An error was encountered while running rgi, " + f"(return code {e.returncode}), please inspect " + "stdout and stderr to learn more." ) - return card_db diff --git a/q2_amr/card/mags.py b/q2_amr/card/mags.py index 3198775..0e6cc40 100644 --- a/q2_amr/card/mags.py +++ b/q2_amr/card/mags.py @@ -6,18 +6,13 @@ import pandas as pd from q2_types_genomics.per_sample_data import MultiMAGSequencesDirFmt -from q2_amr.card.utils import ( - create_count_table, - load_preprocess_card_db, - read_in_txt, - run_command, -) -from q2_amr.types import CARDAnnotationDirectoryFormat, CARDDatabaseFormat +from q2_amr.card.utils import create_count_table, load_card_db, read_in_txt, run_command +from q2_amr.types import CARDAnnotationDirectoryFormat, CARDDatabaseDirectoryFormat def annotate_mags_card( mag: MultiMAGSequencesDirFmt, - card_db: CARDDatabaseFormat, + card_db: CARDDatabaseDirectoryFormat, alignment_tool: str = "BLAST", split_prodigal_jobs: bool = False, include_loose: bool = False, @@ -29,7 +24,7 @@ def annotate_mags_card( amr_annotations = CARDAnnotationDirectoryFormat() frequency_list = [] with tempfile.TemporaryDirectory() as tmp: - load_preprocess_card_db(tmp, card_db, "load") + load_card_db(tmp=tmp, card_db=card_db) for samp_bin in list(manifest.index): bin_dir = os.path.join(str(amr_annotations), samp_bin[0], samp_bin[1]) os.makedirs(bin_dir, exist_ok=True) diff --git a/q2_amr/card/reads.py b/q2_amr/card/reads.py index a1a5a65..724cbc0 100644 --- a/q2_amr/card/reads.py +++ b/q2_amr/card/reads.py @@ -14,15 +14,10 @@ SingleLanePerSampleSingleEndFastqDirFmt, ) -from q2_amr.card.utils import ( - create_count_table, - load_preprocess_card_db, - read_in_txt, - run_command, -) +from q2_amr.card.utils import create_count_table, load_card_db, read_in_txt, run_command from q2_amr.types import ( CARDAlleleAnnotationDirectoryFormat, - CARDDatabaseFormat, + CARDDatabaseDirectoryFormat, CARDGeneAnnotationDirectoryFormat, ) @@ -31,9 +26,11 @@ def annotate_reads_card( reads: Union[ SingleLanePerSamplePairedEndFastqDirFmt, SingleLanePerSampleSingleEndFastqDirFmt ], - card_db: CARDDatabaseFormat, + card_db: CARDDatabaseDirectoryFormat, aligner: str = "kma", threads: int = 1, + include_wildcard: bool = False, + include_other_models: bool = False, ) -> ( CARDAlleleAnnotationDirectoryFormat, CARDGeneAnnotationDirectoryFormat, @@ -46,9 +43,13 @@ def annotate_reads_card( amr_allele_annotation = CARDAlleleAnnotationDirectoryFormat() amr_gene_annotation = CARDGeneAnnotationDirectoryFormat() with tempfile.TemporaryDirectory() as tmp: - load_preprocess_card_db(tmp, card_db, "load") - load_preprocess_card_db(tmp, card_db, "preprocess") - load_preprocess_card_db(tmp, card_db, "load_fasta") + load_card_db( + tmp=tmp, + card_db=card_db, + fasta=True, + include_other_models=include_other_models, + include_wildcard=include_wildcard, + ) for samp in list(manifest.index): fwd = manifest.loc[samp, "forward"] rev = manifest.loc[samp, "reverse"] if paired else None @@ -65,6 +66,8 @@ def annotate_reads_card( rev=rev, aligner=aligner, threads=threads, + include_wildcard=include_wildcard, + include_other_models=include_other_models, ) path_allele = os.path.join(samp_input_dir, "output.allele_mapping_data.txt") allele_frequency = read_in_txt( @@ -109,6 +112,8 @@ def run_rgi_bwt( rev: str, aligner: str, threads: int, + include_wildcard: bool, + include_other_models: bool, ): cmd = [ "rgi", @@ -126,6 +131,10 @@ def run_rgi_bwt( ] if rev: cmd.extend(["--read_two", rev]) + if include_wildcard: + cmd.append("--include_wildcard") + if include_other_models: + cmd.append("--include_other_models") try: run_command(cmd, cwd, verbose=True) except subprocess.CalledProcessError as e: diff --git a/q2_amr/card/utils.py b/q2_amr/card/utils.py index 99415d5..36e4b75 100644 --- a/q2_amr/card/utils.py +++ b/q2_amr/card/utils.py @@ -1,4 +1,6 @@ +import glob import json +import os import subprocess from functools import reduce @@ -21,24 +23,64 @@ def run_command(cmd, cwd, verbose=True): subprocess.run(cmd, check=True, cwd=cwd) -def load_preprocess_card_db(tmp, card_db, operation): - if operation == "load": - cmd = ["rgi", "load", "--card_json", str(card_db), "--local"] - elif operation == "preprocess": - cmd = ["rgi", "card_annotation", "-i", str(card_db)] - elif operation == "load_fasta": - with open(str(card_db)) as f: +def load_card_db( + tmp, + card_db, + kmer_db=None, + kmer: bool = False, + fasta: bool = False, + include_other_models: bool = False, + include_wildcard: bool = False, +): + # Get path to card.json + path_card_json = os.path.join(str(card_db), "card.json") + + # Base command that only loads card.json into the local database + cmd = ["rgi", "load", "--card_json", path_card_json, "--local"] + + # Define suffixes for card fasta file + models = ("_all", "_all_models") if include_other_models is True else ("", "") + + # Extend base command with flag to load card fasta file + if fasta: + # Retrieve the database version number from card.jason file + with open(path_card_json) as f: card_data = json.load(f) version = card_data["_version"] - cmd = [ - "rgi", - "load", - "-i", - str(card_db), - "--card_annotation", - f"card_database_v{version}.fasta", - "--local", - ] + + # Define path to card fasta file + path_card_fasta = os.path.join( + str(card_db), f"card_database_v{version}{models[0]}.fasta" + ) + + # Extend base command + cmd.extend([f"--card_annotation{models[1]}", path_card_fasta]) + + # Extend base command with flag to load wildcard fasta file and index + if include_wildcard: + cmd.extend( + [ + f"--wildcard_annotation{models[1]}", + os.path.join(str(card_db), f"wildcard_database_v0{models[0]}.fasta"), + "--wildcard_index", + os.path.join(str(card_db), "index-for-model-sequences.txt"), + ] + ) + # Extend base command with flag to load kmer json and txt database files + if kmer: + path_kmer_json = glob.glob(os.path.join(str(kmer_db), "*_kmer_db.json"))[0] + cmd.extend( + [ + "--kmer_database", + path_kmer_json, + "--amr_kmers", + glob.glob(os.path.join(str(kmer_db), "all_amr_*mers.txt"))[0], + "--kmer_size", + os.path.basename(path_kmer_json).split("_")[0], + ] + ) + + # Run command try: run_command(cmd, tmp, verbose=True) except subprocess.CalledProcessError as e: diff --git a/q2_amr/plugin_setup.py b/q2_amr/plugin_setup.py index 84efb88..408f727 100644 --- a/q2_amr/plugin_setup.py +++ b/q2_amr/plugin_setup.py @@ -36,8 +36,18 @@ CARDAnnotationStatsFormat, CARDGeneAnnotationDirectoryFormat, CARDGeneAnnotationFormat, + CARDKmerDatabaseDirectoryFormat, + CARDKmerJSONFormat, + CARDKmerTXTFormat, + CARDWildcardIndexFormat, + GapDNAFASTAFormat, +) +from q2_amr.types._type import ( + CARDAlleleAnnotation, + CARDAnnotation, + CARDGeneAnnotation, + CARDKmerDatabase, ) -from q2_amr.types._type import CARDAlleleAnnotation, CARDAnnotation, CARDGeneAnnotation citations = Citations.load("citations.bib", package="q2_amr") @@ -55,15 +65,19 @@ function=fetch_card_db, inputs={}, parameters={}, - outputs=[("card_db", CARDDatabase)], + outputs=[("card_db", CARDDatabase), ("kmer_db", CARDKmerDatabase)], input_descriptions={}, parameter_descriptions={}, output_descriptions={ - "card_db": "CARD database of resistance genes, their products and associated " - "phenotypes." + "card_db": "CARD and WildCARD database of resistance genes, their products and " + "associated phenotypes.", + "kmer_db": "Database of k-mers that are uniquely found within AMR alleles of " + "individual pathogen species, pathogen genera, pathogen-restricted " + "plasmids, or promiscuous plasmids. The default k-mer length is 61 " + "bp, but users can create k-mers of any length.", }, - name="Download CARD data.", - description=("Download the latest version of the CARD database."), + name="Download CARD and WildCARD data.", + description="Download the latest version of the CARD and WildCARD databases.", citations=[citations["alcock_card_2023"]], ) @@ -90,7 +104,7 @@ "alignment_tool": "Specify alignment tool BLAST or DIAMOND.", "split_prodigal_jobs": "Run multiple prodigal jobs simultaneously for contigs" " in one sample", - "include_loose": "Include loose hits in addition to strict and perfect hits .", + "include_loose": "Include loose hits in addition to strict and perfect hits.", "include_nudge": "Include hits nudged from loose to strict hits.", "low_quality": "Use for short contigs to predict partial genes.", "threads": "Number of threads (CPUs) to use in the BLAST search.", @@ -104,7 +118,6 @@ citations=[citations["alcock_card_2023"]], ) - plugin.methods.register_function( function=annotate_reads_card, inputs={ @@ -114,6 +127,8 @@ parameters={ "aligner": Str % Choices(["kma", "bowtie2", "bwa"]), "threads": Int % Range(0, None, inclusive_start=False), + "include_wildcard": Bool, + "include_other_models": Bool, }, outputs=[ ("amr_allele_annotation", SampleData[CARDAlleleAnnotation]), @@ -128,6 +143,21 @@ parameter_descriptions={ "aligner": "Specify alignment tool.", "threads": "Number of threads (CPUs) to use.", + "include_wildcard": "Additionally align reads to the in silico predicted " + "allelic variants available in CARD's Resistomes & Variants" + " data set. This is highly recommended for non-clinical " + "samples .", + "include_other_models": "The default settings for will align reads against " + "CARD's protein homolog models. With include_other_" + "models set to True reads are additionally aligned to " + "protein variant models, rRNA mutation models, and " + "protein over-expression models. These three model " + "types require comparison to CARD's curated lists of " + "mutations known to confer phenotypic antibiotic " + "resistance to differentiate alleles conferring " + "resistance from antibiotic susceptible alleles, " + "but RGI as of yet does not perform this comparison. " + "Use these results with caution.", }, output_descriptions={ "amr_allele_annotation": "AMR annotation mapped on alleles.", @@ -179,9 +209,16 @@ # Registrations plugin.register_semantic_types( - CARDDatabase, CARDAnnotation, CARDAlleleAnnotation, CARDGeneAnnotation + CARDDatabase, + CARDKmerDatabase, + CARDAnnotation, + CARDAlleleAnnotation, + CARDGeneAnnotation, ) +plugin.register_semantic_type_to_format( + CARDKmerDatabase, artifact_format=CARDKmerDatabaseDirectoryFormat +) plugin.register_semantic_type_to_format( CARDDatabase, artifact_format=CARDDatabaseDirectoryFormat ) @@ -197,6 +234,11 @@ ) plugin.register_formats( + CARDKmerDatabaseDirectoryFormat, + CARDKmerJSONFormat, + CARDKmerTXTFormat, + GapDNAFASTAFormat, + CARDWildcardIndexFormat, CARDAnnotationTXTFormat, CARDAnnotationJSONFormat, CARDAnnotationDirectoryFormat, diff --git a/q2_amr/tests/card/test_database.py b/q2_amr/tests/card/test_database.py index e4ec9b7..eb9f4a7 100644 --- a/q2_amr/tests/card/test_database.py +++ b/q2_amr/tests/card/test_database.py @@ -1,30 +1,87 @@ import os +import shutil +import subprocess import tarfile -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, call, patch import requests from qiime2.plugin.testing import TestPluginBase -from q2_amr.card.database import fetch_card_db -from q2_amr.types import CARDDatabaseDirectoryFormat +from q2_amr.card.database import fetch_card_db, preprocess +from q2_amr.types import CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFormat class TestAnnotateMagsCard(TestPluginBase): package = "q2_amr.tests" + def mock_preprocess(self, dir, operation): + if operation == "card": + src_des = [ + ("DNA_fasta.fasta", "card_database_v3.2.7.fasta"), + ("DNA_fasta_-.fasta", "card_database_v3.2.7_all.fasta"), + ] + else: + src_des = [ + ("DNA_fasta.fasta", "wildcard_database_v0.fasta"), + ("DNA_fasta_-.fasta", "wildcard_database_v0_all.fasta"), + ] + for file_name_src, file_name_des in src_des: + shutil.copy( + self.get_data_path(file_name_src), os.path.join(dir, file_name_des) + ) + def test_fetch_card_db(self): - f = open(self.get_data_path("card.tar.bz2"), "rb") - mock_response = MagicMock(raw=f) - with patch("requests.get") as mock_requests: - mock_requests.return_value = mock_response + # Open dummy archives for CARD and WildCARD download + f_card = open(self.get_data_path("card.tar.bz2"), "rb") + f_wildcard = open(self.get_data_path("wildcard_data.tar.bz2"), "rb") + + # Create MagicMock objects to simulate responses from requests.get + mock_response_card = MagicMock(raw=f_card) + mock_response_wildcard = MagicMock(raw=f_wildcard) + + # Patch requests.get, + with patch("requests.get") as mock_requests, patch( + "q2_amr.card.database.preprocess", side_effect=self.mock_preprocess + ): + # Assign MagicMock objects as side effects and run the function + mock_requests.side_effect = [mock_response_card, mock_response_wildcard] obs = fetch_card_db() - self.assertTrue(os.path.exists(os.path.join(str(obs), "card.json"))) - self.assertIsInstance(obs, CARDDatabaseDirectoryFormat) - mock_requests.assert_called_once_with( - "https://card.mcmaster.ca/latest/data", stream=True - ) - def test_fetch_card_data_connection_error(self): + # Lists of filenames contained in CARD and Kmer database objects + files_card_db = [ + "index-for-model-sequences.txt", + "nucleotide_fasta_protein_homolog_model_variants.fasta", + "nucleotide_fasta_protein_overexpression_model_variants.fasta", + "nucleotide_fasta_protein_variant_model_variants.fasta", + "nucleotide_fasta_rRNA_gene_variant_model_variants.fasta", + "wildcard_database_v0.fasta", + "wildcard_database_v0_all.fasta", + "card_database_v3.2.7.fasta", + "card_database_v3.2.7_all.fasta", + "card.json", + ] + files_kmer_db = ["all_amr_61mers.txt", "61_kmer_db.json"] + + # Assert if all files are in the correct database object + for file_list, db_obj in zip( + [files_card_db, files_kmer_db], [str(obs[0]), str(obs[1])] + ): + for file in file_list: + self.assertTrue(os.path.exists(os.path.join(db_obj, file))) + + # Assert if both database objects have the correct format + self.assertIsInstance(obs[0], CARDDatabaseDirectoryFormat) + self.assertIsInstance(obs[1], CARDKmerDatabaseDirectoryFormat) + + # Assert if requests.get gets called with the correct URLs + expected_calls = [ + call("https://card.mcmaster.ca/latest/data", stream=True), + call("https://card.mcmaster.ca/latest/variants", stream=True), + ] + mock_requests.assert_has_calls(expected_calls) + + def test_connection_error(self): + # Simulate a ConnectionError during requests.get with patch( "requests.get", side_effect=requests.ConnectionError ), self.assertRaisesRegex( @@ -32,8 +89,52 @@ def test_fetch_card_data_connection_error(self): ): fetch_card_db() - def test_fetch_card_data_tarfile_read_error(self): - with patch( - "tarfile.open", side_effect=tarfile.ReadError + def test_tarfile_read_error(self): + # Simulate a tarfile.ReadError during tarfile.open + with patch("tarfile.open", side_effect=tarfile.ReadError), patch( + "requests.get" ), self.assertRaisesRegex(tarfile.ReadError, "Tarfile is invalid."): fetch_card_db() + + def test_subprocess_error(self): + # Simulate a subprocess.CalledProcessError during run_command + with patch( + "q2_amr.card.database.run_command", + side_effect=subprocess.CalledProcessError(1, "cmd"), + ), self.assertRaisesRegex( + Exception, + "An error was encountered while running rgi, " + r"\(return code 1\), please inspect stdout and stderr to learn more.", + ): + preprocess("path", "card") + + def test_preprocess_card(self): + # Ensure preprocess calls run_command with the correct arguments for "card" + # operation + with patch("q2_amr.card.database.run_command") as mock_run_command: + preprocess("path_tmp", "card") + mock_run_command.assert_called_once_with( + ["rgi", "card_annotation", "-i", "card/card.json"], + "path_tmp", + verbose=True, + ) + + def test_preprocess_wildcard(self): + # Ensure preprocess calls run_command with the correct arguments for "wildcard" + # operation + with patch("q2_amr.card.database.run_command") as mock_run_command: + preprocess("path_tmp", "wildcard") + mock_run_command.assert_called_once_with( + [ + "rgi", + "wildcard_annotation", + "-i", + "wildcard", + "--card_json", + "card/card.json", + "-v", + "0", + ], + "path_tmp", + verbose=True, + ) diff --git a/q2_amr/tests/card/test_mags.py b/q2_amr/tests/card/test_mags.py index 8573bda..9096137 100644 --- a/q2_amr/tests/card/test_mags.py +++ b/q2_amr/tests/card/test_mags.py @@ -60,7 +60,7 @@ def test_annotate_mags_card(self): mock_read_in_txt = MagicMock() with patch( "q2_amr.card.mags.run_rgi_main", side_effect=self.mock_run_rgi_main - ), patch("q2_amr.card.mags.load_preprocess_card_db"), patch( + ), patch("q2_amr.card.mags.load_card_db"), patch( "q2_amr.card.mags.read_in_txt", mock_read_in_txt ), patch( "q2_amr.card.mags.create_count_table", mock_create_count_table diff --git a/q2_amr/tests/card/test_reads.py b/q2_amr/tests/card/test_reads.py index 1685454..2b3a107 100644 --- a/q2_amr/tests/card/test_reads.py +++ b/q2_amr/tests/card/test_reads.py @@ -22,7 +22,7 @@ ) from q2_amr.types import ( CARDAlleleAnnotationDirectoryFormat, - CARDDatabaseFormat, + CARDDatabaseDirectoryFormat, CARDGeneAnnotationDirectoryFormat, ) @@ -30,6 +30,21 @@ class TestAnnotateReadsCARD(TestPluginBase): package = "q2_amr.tests" + @classmethod + def setUpClass(cls): + cls.sample_stats = { + "sample1": { + "total_reads": 5000, + "mapped_reads": 59, + "percentage": 1.18, + }, + "sample2": { + "total_reads": 7000, + "mapped_reads": 212, + "percentage": 3.03, + }, + } + def test_annotate_reads_card_single(self): self.annotate_reads_card_test_body("single") @@ -37,112 +52,106 @@ def test_annotate_reads_card_paired(self): self.annotate_reads_card_test_body("paired") def copy_needed_files(self, cwd, samp, **kwargs): - output_allele = self.get_data_path("output.allele_mapping_data.txt") - output_gene = self.get_data_path("output.gene_mapping_data.txt") - output_stats = self.get_data_path("output.overall_mapping_stats.txt") + # Create a sample directory samp_dir = os.path.join(cwd, samp) - shutil.copy(output_allele, samp_dir) - shutil.copy(output_gene, samp_dir) - shutil.copy(output_stats, samp_dir) + + # Copy three dummy files to the directory + for a, b in zip(["allele", "gene", "overall"], ["data", "data", "stats"]): + shutil.copy(self.get_data_path(f"output.{a}_mapping_{b}.txt"), samp_dir) def annotate_reads_card_test_body(self, read_type): + # Create single end or paired end reads object and CARD database object + reads = ( + SingleLanePerSampleSingleEndFastqDirFmt() + if read_type == "single" + else SingleLanePerSamplePairedEndFastqDirFmt() + ) + card_db = CARDDatabaseDirectoryFormat() + + # Copy manifest file to reads object manifest = self.get_data_path(f"MANIFEST_reads_{read_type}") - if read_type == "single": - reads = SingleLanePerSampleSingleEndFastqDirFmt() - shutil.copy(manifest, os.path.join(str(reads), "MANIFEST")) - else: - reads = SingleLanePerSamplePairedEndFastqDirFmt() - shutil.copy(manifest, os.path.join(str(reads), "MANIFEST")) - card_db = CARDDatabaseFormat() + shutil.copy(manifest, os.path.join(str(reads), "MANIFEST")) + + # Create MagicMock objects for run_rgi_bwt, run_rgi_load, read_in_txt and + # create_count_table functions mock_run_rgi_bwt = MagicMock(side_effect=self.copy_needed_files) mock_run_rgi_load = MagicMock() mock_read_in_txt = MagicMock() - mag_test_class = TestAnnotateMagsCard() mock_create_count_table = MagicMock( - side_effect=mag_test_class.return_count_table + side_effect=TestAnnotateMagsCard().return_count_table ) + + # Patch run_rgi_bwt, run_rgi_load, read_in_txt and create_count_table functions + # and assign MagicMock objects with patch("q2_amr.card.reads.run_rgi_bwt", mock_run_rgi_bwt), patch( - "q2_amr.card.reads.load_preprocess_card_db", mock_run_rgi_load + "q2_amr.card.reads.load_card_db", mock_run_rgi_load ), patch("q2_amr.card.reads.read_in_txt", mock_read_in_txt), patch( "q2_amr.card.reads.create_count_table", mock_create_count_table ): + + # Run annotate_reads_card function result = annotate_reads_card(reads, card_db) + + # Retrieve the path to cwd directory from mock_run_rgi_bwt arguments first_call_args = mock_run_rgi_bwt.call_args_list[0] tmp_dir = first_call_args.kwargs["cwd"] - if read_type == "single": - exp_calls_mock_run = [ - call( - cwd=tmp_dir, - samp="sample1", - fwd=f"{reads}/sample1_00_L001_R1_001.fastq.gz", - aligner="kma", - rev=None, - threads=1, - ), - call( - cwd=tmp_dir, - samp="sample2", - fwd=f"{reads}/sample2_00_L001_R1_001.fastq.gz", - aligner="kma", - rev=None, - threads=1, - ), - ] - else: - exp_calls_mock_run = [ - call( - cwd=tmp_dir, - samp="sample1", - fwd=f"{reads}/sample1_00_L001_R1_001.fastq.gz", - rev=f"{reads}/sample1_00_L001_R2_001.fastq.gz", - aligner="kma", - threads=1, - ), - call( - cwd=tmp_dir, - samp="sample2", - fwd=f"{reads}/sample2_00_L001_R1_001.fastq.gz", - rev=f"{reads}/sample2_00_L001_R2_001.fastq.gz", - aligner="kma", - threads=1, - ), - ] - exp_calls_mock_load = [ - call(tmp_dir, ANY, "load"), - call(tmp_dir, ANY, "preprocess"), - call(tmp_dir, ANY, "load_fasta"), - ] - exp_calls_mock_read = [ + + # Create four expected call objects for mock_run_rgi_bwt + exp_calls_mock_bwt = [ call( - path=f"{tmp_dir}/sample1/output.allele_mapping_data.txt", - col_name="ARO Accession", - samp_bin_name="sample1", - ), + cwd=tmp_dir, + aligner="kma", + threads=1, + include_wildcard=False, + include_other_models=False, + samp=f"sample{i}", + fwd=f"{reads}/sample{i}_00_L001_R1_001.fastq.gz", + rev=None + if read_type == "single" + else f"{reads}/sample{i}_00_L001_R2_001.fastq.gz", + ) + for i in range(1, 3) + ] + + # Expected call object for mock_run_rgi_load + exp_calls_mock_load = [ call( - path=f"{tmp_dir}/sample1/output.gene_mapping_data.txt", - col_name="ARO Accession", - samp_bin_name="sample1", - ), - call( - path=f"{tmp_dir}/sample2/output.allele_mapping_data.txt", - col_name="ARO Accession", - samp_bin_name="sample2", + tmp=tmp_dir, + card_db=ANY, + fasta=True, + include_other_models=False, + include_wildcard=False, ), + ] + + # Create four expected call objects for mock_read_in_txt + exp_calls_mock_read = [ call( - path=f"{tmp_dir}/sample2/output.gene_mapping_data.txt", + path=f"{tmp_dir}/{samp}/output.{model}_mapping_data.txt", col_name="ARO Accession", - samp_bin_name="sample2", - ), + samp_bin_name=samp, + ) + for samp in ["sample1", "sample2"] + for model in ["allele", "gene"] ] + + # Expected call objects for mock_create_count_table exp_calls_mock_count = [call([ANY, ANY]), call([ANY, ANY])] - mock_run_rgi_bwt.assert_has_calls(exp_calls_mock_run) + + # Assert if all patched function were called with the expected calls + mock_run_rgi_bwt.assert_has_calls(exp_calls_mock_bwt) mock_run_rgi_load.assert_has_calls(exp_calls_mock_load) mock_read_in_txt.assert_has_calls(exp_calls_mock_read) mock_create_count_table.assert_has_calls(exp_calls_mock_count) + + # Assert if all output files are the expected format self.assertIsInstance(result[0], CARDAlleleAnnotationDirectoryFormat) self.assertIsInstance(result[1], CARDGeneAnnotationDirectoryFormat) self.assertIsInstance(result[2], pd.DataFrame) self.assertIsInstance(result[3], pd.DataFrame) + + # Assert if the expected files are in every sample directory and in both + # resulting CARD annotation objects for num in [0, 1]: map_type = "allele" if num == 0 else "gene" for samp in ["sample1", "sample2"]: @@ -163,6 +172,8 @@ def test_run_rgi_bwt(self): "path_rev", "bowtie2", 8, + True, + True, ) mock_run_command.assert_called_once_with( [ @@ -180,6 +191,8 @@ def test_run_rgi_bwt(self): "bowtie2", "--read_two", "path_rev", + "--include_wildcard", + "--include_other_models", ], "path_tmp", verbose=True, @@ -195,14 +208,7 @@ def test_exception_raised(self): "q2_amr.card.reads.run_command" ) as mock_run_command, self.assertRaises(Exception) as cm: mock_run_command.side_effect = subprocess.CalledProcessError(1, "cmd") - run_rgi_bwt( - cwd="path/cwd", - samp="sample1", - fwd="path/fwd", - rev="path/rev", - aligner="bwa", - threads=1, - ) + run_rgi_bwt() self.assertEqual(str(cm.exception), expected_message) def test_move_files_allele(self): @@ -248,45 +254,37 @@ def test_extract_sample_stats(self): new_mapping_stats_path = os.path.join(tmp, "overall_mapping_stats.txt") shutil.copy(mapping_stats_path, new_mapping_stats_path) sample_stats = extract_sample_stats(tmp) - expected_result = { - "total_reads": 5000, - "mapped_reads": 59, - "percentage": 1.18, - } - self.assertEqual(sample_stats, expected_result) + + self.assertEqual(sample_stats, self.sample_stats["sample1"]) def test_plot_sample_stats(self): with tempfile.TemporaryDirectory() as tmp: - sample_stats = { - "sample1": { - "total_reads": 5000, - "mapped_reads": 106, - "percentage": 2.12, - }, - "sample2": { - "total_reads": 7000, - "mapped_reads": 212, - "percentage": 3.03, - }, - } - plot_sample_stats(sample_stats, tmp) + plot_sample_stats(self.sample_stats, tmp) self.assertTrue(os.path.exists(os.path.join(tmp, "sample_stats_plot.html"))) - def test_visualize_annotation_stats(self): - def mock_plot_sample_stats(sample_stats, output_dir): - with open(os.path.join(tmp, "sample_stats_plot.html"), "w") as file: - file.write("file") + def mock_plot_sample_stats(self, sample_stats, output_dir): + # Create a dummy HTML file and copy it to the output_dir + with open(os.path.join(output_dir, "sample_stats_plot.html"), "w") as file: + file.write("file") + def test_visualize_annotation_stats(self): + # Create a CARDGeneAnnotation object amr_reads_annotation = CARDGeneAnnotationDirectoryFormat() - sample1_dir = os.path.join(str(amr_reads_annotation), "sample1") - sample2_dir = os.path.join(str(amr_reads_annotation), "sample2") - os.makedirs(sample1_dir) - os.makedirs(sample2_dir) + + # Create two sample directories in the CARDGeneAnnotation object + for num in range(1, 3): + os.makedirs(os.path.join(str(amr_reads_annotation), f"sample{num}")) + + # Patch extract_sample_stats and plot_sample_stats with side effect + # mock_plot_sample_stats with patch("q2_amr.card.reads.extract_sample_stats"), patch( "q2_amr.card.reads.plot_sample_stats", - side_effect=mock_plot_sample_stats, + side_effect=self.mock_plot_sample_stats, ), tempfile.TemporaryDirectory() as tmp: + + # Run visualize_annotation_stats function visualize_annotation_stats(tmp, amr_reads_annotation) - self.assertTrue(os.path.exists(os.path.join(tmp, "sample_stats_plot.html"))) - self.assertTrue(os.path.exists(os.path.join(tmp, "index.html"))) - self.assertTrue(os.path.exists(os.path.join(tmp, "q2templateassets"))) + + # Assert if all expected files are created + for file in ["sample_stats_plot.html", "index.html", "q2templateassets"]: + self.assertTrue(os.path.exists(os.path.join(tmp, file))) diff --git a/q2_amr/tests/card/test_utils.py b/q2_amr/tests/card/test_utils.py index 10d3bc4..a6a6a9c 100644 --- a/q2_amr/tests/card/test_utils.py +++ b/q2_amr/tests/card/test_utils.py @@ -1,87 +1,121 @@ +import os +import shutil import subprocess -from unittest.mock import patch +from unittest.mock import call, patch import pandas as pd from qiime2.plugin.testing import TestPluginBase from test_mags import TestAnnotateMagsCard -from q2_amr.card.utils import create_count_table, load_preprocess_card_db, read_in_txt -from q2_amr.types import CARDDatabaseFormat +from q2_amr.card.utils import create_count_table, load_card_db, read_in_txt +from q2_amr.types import CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFormat class TestAnnotateReadsCARD(TestPluginBase): package = "q2_amr.tests" - mapping_data_sample1 = pd.DataFrame( - { - "ARO Accession": [3000796, 3000815, 3000805, 3000026], - "sample1": [1, 1, 1, 1], - } - ) - - mags_mapping_data_sample1 = pd.DataFrame( - { - "ARO": [3000796, 3000815, 3000805, 3000026], - "sample1": [1, 1, 1, 1], - } - ) - - mapping_data_sample2 = pd.DataFrame( - { - "ARO Accession": [3000797, 3000815, 3000805, 3000026], - "sample2": [1, 1, 1, 2], - } - ) - - def test_load_card_db(self): - card_db = CARDDatabaseFormat() - with patch("q2_amr.card.utils.run_command") as mock_run_command: - load_preprocess_card_db("path_tmp", card_db, "load") - mock_run_command.assert_called_once_with( - ["rgi", "load", "--card_json", str(card_db), "--local"], - "path_tmp", - verbose=True, - ) - - def test_preprocess_card_db(self): - card_db = CARDDatabaseFormat() - with patch("q2_amr.card.utils.run_command") as mock_run_command: - load_preprocess_card_db("path_tmp", card_db, "preprocess") - mock_run_command.assert_called_once_with( - ["rgi", "card_annotation", "-i", str(card_db)], "path_tmp", verbose=True - ) + @classmethod + def setUpClass(cls): + cls.mapping_data_sample1 = pd.DataFrame( + { + "ARO Accession": [3000796, 3000815, 3000805, 3000026], + "sample1": [1, 1, 1, 1], + } + ) + + cls.mapping_data_sample2 = pd.DataFrame( + { + "ARO Accession": [3000797, 3000815, 3000805, 3000026], + "sample2": [1, 1, 1, 2], + } + ) + + cls.mags_mapping_data_sample1 = pd.DataFrame( + { + "ARO": [3000796, 3000815, 3000805, 3000026], + "sample1": [1, 1, 1, 1], + } + ) def test_load_card_db_fasta(self): - card_db = self.get_data_path("card_test.json") + # Create CARD and Kmer database objects + card_db = CARDDatabaseDirectoryFormat() + kmer_db = CARDKmerDatabaseDirectoryFormat() + + # Tuples with source file name, destination file name and destination directory + src_des_dir = [ + ("card_test.json", "card.json", card_db), + ("kmer_txt_test.txt", "all_amr_61mers.txt", kmer_db), + ("kmer_json_test.json", "61_kmer_db.json", kmer_db), + ] + + # Copy files in src_des_dir to CARD and Kmer database objects + for src, des, dir in src_des_dir: + shutil.copy(self.get_data_path(src), os.path.join(str(dir), des)) + + # Patch run_command with patch("q2_amr.card.utils.run_command") as mock_run_command: - load_preprocess_card_db("path_tmp", card_db, "load_fasta") - mock_run_command.assert_called_once_with( - [ - "rgi", - "load", - "-i", - str(card_db), - "--card_annotation", - "card_database_v3.2.5.fasta", - "--local", - ], - "path_tmp", - verbose=True, - ) + # Run load_card_db two times with include_other_models set to True and False + for parameters in [False, True]: + load_card_db( + tmp="path_tmp", + card_db=card_db, + kmer_db=kmer_db, + kmer=True, + fasta=True, + include_wildcard=True, + include_other_models=parameters, + ) + + # Create two expected call objects + flags = ["", "_all_models"] + parameters = ["", "_all"] + + expected_calls = [ + call( + [ + "rgi", + "load", + "--card_json", + os.path.join(str(card_db), "card.json"), + "--local", + f"--card_annotation{flag}", + os.path.join( + str(card_db), f"card_database_v3.2.5{parameter}.fasta" + ), + f"--wildcard_annotation{flag}", + os.path.join( + str(card_db), f"wildcard_database_v0{parameter}.fasta" + ), + "--wildcard_index", + os.path.join(str(card_db), "index-for-model-sequences.txt"), + "--kmer_database", + os.path.join(str(kmer_db), "61_kmer_db.json"), + "--amr_kmers", + os.path.join(str(kmer_db), "all_amr_61mers.txt"), + "--kmer_size", + "61", + ], + "path_tmp", + verbose=True, + ) + for flag, parameter in zip(flags, parameters) + ] + + # Assert if function was called with expected calls + mock_run_command.assert_has_calls(expected_calls, any_order=False) def test_exception_raised(self): - tmp = "path/to/tmp" - card_db = "path/to/card_db.json" + # Simulate a subprocess.CalledProcessError during run_command expected_message = ( "An error was encountered while running rgi, " "(return code 1), please inspect stdout and stderr to learn more." ) - operation = "load" with patch( "q2_amr.card.utils.run_command" ) as mock_run_command, self.assertRaises(Exception) as cm: mock_run_command.side_effect = subprocess.CalledProcessError(1, "cmd") - load_preprocess_card_db(tmp, card_db, operation) + load_card_db() self.assertEqual(str(cm.exception), expected_message) def test_read_in_txt_mags(self): diff --git a/q2_amr/tests/data/DNA_fasta.fasta b/q2_amr/tests/data/DNA_fasta.fasta new file mode 100644 index 0000000..3de7be4 --- /dev/null +++ b/q2_amr/tests/data/DNA_fasta.fasta @@ -0,0 +1,2 @@ +>Prevalence_Sequence_ID:17933|ID:3252|Name:Campylobacter_jejuni_23S_rRNA_with_mutation_conferring_resistance_to_erythromycin|ARO:3004546 +AGCTACTAA diff --git a/q2_amr/tests/data/DNA_fasta_-.fasta b/q2_amr/tests/data/DNA_fasta_-.fasta new file mode 100644 index 0000000..33d85a3 --- /dev/null +++ b/q2_amr/tests/data/DNA_fasta_-.fasta @@ -0,0 +1,2 @@ +>Prevalence_Sequence_ID:17933|ID:3252|Name:Campylobacter_jejuni_23S_rRNA_with_mutation_conferring_resistance_to_erythromycin|ARO:3004546 +AGCTACTA-A diff --git a/q2_amr/tests/data/kmer_json_test.json b/q2_amr/tests/data/kmer_json_test.json new file mode 100644 index 0000000..4ae73fb --- /dev/null +++ b/q2_amr/tests/data/kmer_json_test.json @@ -0,0 +1 @@ +{"p": "dummy_value", "c": "dummy_value", "b": "dummy_value", "s": "dummy_value", "g": "dummy_value"} diff --git a/q2_amr/tests/data/kmer_txt_test.txt b/q2_amr/tests/data/kmer_txt_test.txt new file mode 100644 index 0000000..44c3277 --- /dev/null +++ b/q2_amr/tests/data/kmer_txt_test.txt @@ -0,0 +1,9 @@ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA 8 +TTTTCCCGCTGTGCGTGCTGCTGGCTTTCCTGGTGCTGGCCGCCCAATACGAAAGCTGGAG 1 +TACAGCAGCATCGAAGAAGCCTACATGGCGATCTTCCCACCGCCGCCGGTACAGGGCCTGG 3 +TCTCTCTATCAGATTGATCCTGCTACTTACCAGGCCTCTTATGAAAGTGCAAAAGGCGATC 1 +TAGTCGATCTGTCGTTGTTTACGTCGCGCAACTTCACCATCGGCTGCTTGTGTATCAGCCT 1 +AGCGAATTACTAGTTAAGGATAATGACCCTATCGCGCAAGATGTGTATGCCAAAGAAAAAG 2 +CCCGCGACACCACTCCCCCGGCCAGCATGGCCGCGACCCTGCGCAAGCTGCTGACCAGCCA 4 +CGGACGGATCAGCTTCGTCCTGATGGCAATGGCGGTCTTGTTTGCCGGTCTGATTGCCCGC 1 +ATCGAGGAAGTGGTGAAGACGCTGCTCGAGGGTATCGTCCTCGTGTTCCTCGTGATGTATC 252 diff --git a/q2_amr/tests/data/wildcard_data.tar.bz2 b/q2_amr/tests/data/wildcard_data.tar.bz2 new file mode 100644 index 0000000..4444ae0 Binary files /dev/null and b/q2_amr/tests/data/wildcard_data.tar.bz2 differ diff --git a/q2_amr/types/__init__.py b/q2_amr/types/__init__.py index ec51578..7f5d25d 100644 --- a/q2_amr/types/__init__.py +++ b/q2_amr/types/__init__.py @@ -17,12 +17,18 @@ CARDDatabaseFormat, CARDGeneAnnotationDirectoryFormat, CARDGeneAnnotationFormat, + CARDKmerDatabaseDirectoryFormat, + CARDKmerJSONFormat, + CARDKmerTXTFormat, + CARDWildcardIndexFormat, + GapDNAFASTAFormat, ) from ._type import ( CARDAlleleAnnotation, CARDAnnotation, CARDDatabase, CARDGeneAnnotation, + CARDKmerDatabase, ) __all__ = [ @@ -40,4 +46,10 @@ "CARDAnnotation", "CARDAlleleAnnotation", "CARDGeneAnnotation", + "CARDKmerDatabaseDirectoryFormat", + "CARDKmerJSONFormat", + "CARDKmerTXTFormat", + "GapDNAFASTAFormat", + "CARDWildcardIndexFormat", + "CARDKmerDatabase", ] diff --git a/q2_amr/types/_format.py b/q2_amr/types/_format.py index 56548d7..2eb3d36 100644 --- a/q2_amr/types/_format.py +++ b/q2_amr/types/_format.py @@ -6,10 +6,12 @@ # The full license is in the file LICENSE, distributed with this software. # ---------------------------------------------------------------------------- import json +import re from copy import copy import pandas as pd import qiime2.plugin.model as model +from q2_types.feature_data._format import DNAFASTAFormat from q2_types_genomics.per_sample_data._format import MultiDirValidationMixin from qiime2.plugin import ValidationError @@ -50,9 +52,117 @@ def _validate_(self, level): self._validate() -CARDDatabaseDirectoryFormat = model.SingleFileDirectoryFormat( - "CARDDatabaseDirectoryFormat", "card.json", CARDDatabaseFormat -) +class CARDWildcardIndexFormat(model.TextFileFormat): + def _validate(self, n_records=None): + header_exp = [ + "prevalence_sequence_id", + "model_id", + "aro_term", + "aro_accession", + "detection_model", + "species_name", + "ncbi_accession", + "data_type", + "rgi_criteria", + "percent_identity", + "bitscore", + "amr_gene_family", + "resistance_mechanism", + "drug_class", + "card_short_name", + ] + + df = pd.read_csv(str(self), sep="\t") + header_obs = list(df.columns) + if not set(header_exp).issubset(set(header_obs)): + raise ValidationError( + "Values do not match CARDWildcardindexFormat. Must contain" + "the following values: " + + ", ".join(header_exp) + + ".\n\nFound instead: " + + ", ".join(header_obs) + ) + + def _validate_(self, level): + self._validate() + + +class GapDNAFASTAFormat(DNAFASTAFormat): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.alphabet += "-" + + +class CARDDatabaseDirectoryFormat(model.DirectoryFormat): + card_fasta = model.File( + r"card_database_v\d+\.\d+\.\d+.fasta", format=DNAFASTAFormat + ) + card_fasta_all = model.File( + r"card_database_v\d+\.\d+\.\d+_all.fasta", format=GapDNAFASTAFormat + ) + wildcard = model.File("wildcard_database_v0.fasta", format=DNAFASTAFormat) + wildcard_all = model.File( + "wildcard_database_v0_all.fasta", format=GapDNAFASTAFormat + ) + card_json = model.File("card.json", format=CARDDatabaseFormat) + index = model.File("index-for-model-sequences.txt", format=CARDWildcardIndexFormat) + homolog_model = model.File( + "nucleotide_fasta_protein_homolog_model_variants.fasta", format=DNAFASTAFormat + ) + overexpression_model = model.File( + "nucleotide_fasta_protein_overexpression_model_variants.fasta", + format=DNAFASTAFormat, + ) + protein_model = model.File( + "nucleotide_fasta_protein_variant_model_variants.fasta", + format=DNAFASTAFormat, + ) + rRNA_model = model.File( + "nucleotide_fasta_rRNA_gene_variant_model_variants.fasta", + format=GapDNAFASTAFormat, + ) + + +class CARDKmerTXTFormat(model.TextFileFormat): + def _validate(self, n_records=None): + pattern = r"^[AGCT]+\t\d+$" + + with open(str(self), "r") as file: + lines = file.readlines()[:10] + for line in lines: + if not re.match(pattern, line.strip()): + raise ValidationError( + "The provided file is not the correct format. All lines must " + r"match the regex pattern r'^[AGCT]+\t\d+$'." + ) + + def _validate_(self, level): + self._validate() + + +class CARDKmerJSONFormat(model.TextFileFormat): + def _validate(self, n_records=None): + keys_exp = ["p", "c", "b", "s", "g"] + with open(str(self)) as json_file: + kmer_dict = json.load(json_file) + keys_obs = list(kmer_dict.keys()) + + if keys_obs != keys_exp: + raise ValidationError( + "Keys do not match KMERJSON format. Must consist of " + "the following values: " + + ", ".join(keys_exp) + + ".\n\nFound instead: " + + ", ".join(keys_obs) + ) + + def _validate_(self, level): + self._validate() + + +class CARDKmerDatabaseDirectoryFormat(model.DirectoryFormat): + kmer_json = model.File(r"\d+_kmer_db.json", format=CARDKmerJSONFormat) + kmer_fasta = model.File(r"all_amr_\d+mers.txt", format=CARDKmerTXTFormat) class CARDAnnotationTXTFormat(model.TextFileFormat): diff --git a/q2_amr/types/_transformer.py b/q2_amr/types/_transformer.py index 0700715..6d08cf1 100644 --- a/q2_amr/types/_transformer.py +++ b/q2_amr/types/_transformer.py @@ -213,23 +213,6 @@ def card_annotation_df_to_fasta(txt_file_path: str, seq_type: str): return fasta_format -def read_mapping_data(data_path, variant): - df_list = [] - for samp in os.listdir(str(data_path)): - file_path = os.path.join( - str(data_path), samp, f"{samp}.{variant}_mapping_data.txt" - ) - df = pd.read_csv(file_path, sep="\t") - df.insert(0, "Sample Name", samp) - df_list.append(df) - mapping_data_cat = pd.concat(df_list, axis=0) - mapping_data_cat.reset_index(inplace=True, drop=True) - mapping_data_cat.index.name = "id" - mapping_data_cat.index = mapping_data_cat.index.astype(str) - metadata = qiime2.Metadata(mapping_data_cat) - return metadata - - @plugin.register_transformer def _12(data: CARDAlleleAnnotationDirectoryFormat) -> qiime2.Metadata: return tabulate_data(data, "allele") diff --git a/q2_amr/types/_type.py b/q2_amr/types/_type.py index c281c4e..aa2f302 100644 --- a/q2_amr/types/_type.py +++ b/q2_amr/types/_type.py @@ -9,6 +9,7 @@ from qiime2.plugin import SemanticType CARDDatabase = SemanticType("CARDDatabase") +CARDKmerDatabase = SemanticType("CARDKmerDatabase") CARDAnnotation = SemanticType("CARDAnnotation", variant_of=SampleData.field["type"]) CARDAlleleAnnotation = SemanticType( "CARDAlleleAnnotation", variant_of=SampleData.field["type"] diff --git a/q2_amr/types/tests/data/DNA_fasta.fasta b/q2_amr/types/tests/data/DNA_fasta.fasta new file mode 100644 index 0000000..3de7be4 --- /dev/null +++ b/q2_amr/types/tests/data/DNA_fasta.fasta @@ -0,0 +1,2 @@ +>Prevalence_Sequence_ID:17933|ID:3252|Name:Campylobacter_jejuni_23S_rRNA_with_mutation_conferring_resistance_to_erythromycin|ARO:3004546 +AGCTACTAA diff --git a/q2_amr/types/tests/data/DNA_fasta_-.fasta b/q2_amr/types/tests/data/DNA_fasta_-.fasta new file mode 100644 index 0000000..33d85a3 --- /dev/null +++ b/q2_amr/types/tests/data/DNA_fasta_-.fasta @@ -0,0 +1,2 @@ +>Prevalence_Sequence_ID:17933|ID:3252|Name:Campylobacter_jejuni_23S_rRNA_with_mutation_conferring_resistance_to_erythromycin|ARO:3004546 +AGCTACTA-A diff --git a/q2_amr/types/tests/data/index-for-model-sequences-test.txt b/q2_amr/types/tests/data/index-for-model-sequences-test.txt new file mode 100644 index 0000000..636ab96 --- /dev/null +++ b/q2_amr/types/tests/data/index-for-model-sequences-test.txt @@ -0,0 +1,2 @@ +prevalence_sequence_id model_id aro_term aro_accession detection_model species_name ncbi_accession data_type rgi_criteria percent_identity bitscore amr_gene_family resistance_mechanism drug_class card_short_name +1 5731 qacG ARO:3007015 protein homolog model Acinetobacter haemolyticus NZ_WTTI01 ncbi_contig Strict 41.9 87.8 small multidrug resistance (SMR) antibiotic efflux pump antibiotic efflux disinfecting agents and antiseptics qacG diff --git a/q2_amr/types/tests/data/kmer_json_test.json b/q2_amr/types/tests/data/kmer_json_test.json new file mode 100644 index 0000000..4ae73fb --- /dev/null +++ b/q2_amr/types/tests/data/kmer_json_test.json @@ -0,0 +1 @@ +{"p": "dummy_value", "c": "dummy_value", "b": "dummy_value", "s": "dummy_value", "g": "dummy_value"} diff --git a/q2_amr/types/tests/data/kmer_txt_test.txt b/q2_amr/types/tests/data/kmer_txt_test.txt new file mode 100644 index 0000000..44c3277 --- /dev/null +++ b/q2_amr/types/tests/data/kmer_txt_test.txt @@ -0,0 +1,9 @@ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA 8 +TTTTCCCGCTGTGCGTGCTGCTGGCTTTCCTGGTGCTGGCCGCCCAATACGAAAGCTGGAG 1 +TACAGCAGCATCGAAGAAGCCTACATGGCGATCTTCCCACCGCCGCCGGTACAGGGCCTGG 3 +TCTCTCTATCAGATTGATCCTGCTACTTACCAGGCCTCTTATGAAAGTGCAAAAGGCGATC 1 +TAGTCGATCTGTCGTTGTTTACGTCGCGCAACTTCACCATCGGCTGCTTGTGTATCAGCCT 1 +AGCGAATTACTAGTTAAGGATAATGACCCTATCGCGCAAGATGTGTATGCCAAAGAAAAAG 2 +CCCGCGACACCACTCCCCCGGCCAGCATGGCCGCGACCCTGCGCAAGCTGCTGACCAGCCA 4 +CGGACGGATCAGCTTCGTCCTGATGGCAATGGCGGTCTTGTTTGCCGGTCTGATTGCCCGC 1 +ATCGAGGAAGTGGTGAAGACGCTGCTCGAGGGTATCGTCCTCGTGTTCCTCGTGATGTATC 252 diff --git a/q2_amr/types/tests/test_types_formats_transformers.py b/q2_amr/types/tests/test_types_formats_transformers.py index db5414b..a594eaa 100644 --- a/q2_amr/types/tests/test_types_formats_transformers.py +++ b/q2_amr/types/tests/test_types_formats_transformers.py @@ -7,6 +7,7 @@ # ---------------------------------------------------------------------------- import json import os +import shutil import tempfile import pandas as pd @@ -25,12 +26,18 @@ from q2_amr.types import ( CARDAlleleAnnotationDirectoryFormat, + CARDDatabaseDirectoryFormat, CARDGeneAnnotationDirectoryFormat, ) from q2_amr.types._format import ( CARDAnnotationDirectoryFormat, CARDAnnotationTXTFormat, CARDDatabaseFormat, + CARDKmerDatabaseDirectoryFormat, + CARDKmerJSONFormat, + CARDKmerTXTFormat, + CARDWildcardIndexFormat, + GapDNAFASTAFormat, ) from q2_amr.types._transformer import ( _read_from_card_file, @@ -59,6 +66,48 @@ def test_card_database_format_validate_positive(self): format = CARDDatabaseFormat(filepath, mode="r") format.validate() + def test_wildcard_index_format_validate_positive(self): + filepath = self.get_data_path("index-for-model-sequences-test.txt") + format = CARDWildcardIndexFormat(filepath, mode="r") + format.validate() + + def test_extended_dna_fasta_format_validate_positive(self): + filepath = self.get_data_path("DNA_fasta_-.fasta") + format = GapDNAFASTAFormat(filepath, mode="r") + format.validate() + + def test_card_database_directory_format_validate_positive(self): + src_des_list = [ + ("card_test.json", "card.json"), + ("DNA_fasta.fasta", "card_database_v3.2.7.fasta"), + ("DNA_fasta_-.fasta", "card_database_v3.2.7_all.fasta"), + ("DNA_fasta.fasta", "wildcard_database_v0.fasta"), + ("DNA_fasta_-.fasta", "wildcard_database_v0_all.fasta"), + ("index-for-model-sequences-test.txt", "index-for-model-sequences.txt"), + ( + "DNA_fasta.fasta", + "nucleotide_fasta_protein_homolog_model_variants.fasta", + ), + ( + "DNA_fasta.fasta", + "nucleotide_fasta_protein_overexpression_model_variants.fasta", + ), + ( + "DNA_fasta.fasta", + "nucleotide_fasta_protein_variant_model_variants.fasta", + ), + ( + "DNA_fasta_-.fasta", + "nucleotide_fasta_rRNA_gene_variant_model_variants.fasta", + ), + ] + for scr_file, des_file in src_des_list: + shutil.copy( + self.get_data_path(scr_file), os.path.join(self.temp_dir.name, des_file) + ) + format = CARDDatabaseDirectoryFormat(self.temp_dir.name, mode="r") + format.validate() + def test_dataframe_to_card_format_transformer(self): filepath = self.get_data_path("card_test.json") transformer = self.get_transformer(pd.DataFrame, CARDDatabaseFormat) @@ -164,6 +213,30 @@ def test_read_from_card_generator(self): self.assertIsInstance(generator, ProteinIterator) +class TestCARDCARDKmerDirectoryTypesAndFormats(AMRTypesTestPluginBase): + def test_kmer_txt_format_validate_positive(self): + filepath = self.get_data_path("kmer_txt_test.txt") + format = CARDKmerTXTFormat(filepath, mode="r") + format.validate() + + def test_kmer_json_format_validate_positive(self): + filepath = self.get_data_path("kmer_json_test.json") + format = CARDKmerJSONFormat(filepath, mode="r") + format.validate() + + def test_card_kmer_database_directory_format_validate_positive(self): + src_des_list = [ + ("kmer_json_test.json", "61_kmer_db.json"), + ("kmer_txt_test.txt", "all_amr_61mers.txt"), + ] + for scr_file, des_file in src_des_list: + shutil.copy( + self.get_data_path(scr_file), os.path.join(self.temp_dir.name, des_file) + ) + format = CARDKmerDatabaseDirectoryFormat(self.temp_dir.name, mode="r") + format.validate() + + class TestCARDMagsAnnotationTypesAndFormats(AMRTypesTestPluginBase): def test_df_to_card_annotation_format_transformer(self): filepath = self.get_data_path("rgi_output.txt") diff --git a/setup.py b/setup.py index dbe7488..048b094 100644 --- a/setup.py +++ b/setup.py @@ -25,9 +25,16 @@ package_data={ "q2_amr": [ "citations.bib", + "tests/data/*", "assets/rgi/annotation_stats/*", "assets/rgi/heatmap/*", ], + "q2_amr.types.tests": [ + "data/*", + "data/annotate_mags_output/*/*/*", + "data/annotate_reads_output/*/*", + ], + "q2_amr.tests": ["data/*"], }, zip_safe=False, )