diff --git a/q2_amr/card/mags.py b/q2_amr/card/mags.py index 97e96fa..0ca6832 100644 --- a/q2_amr/card/mags.py +++ b/q2_amr/card/mags.py @@ -24,7 +24,7 @@ def annotate_mags_card( amr_annotations = CARDAnnotationDirectoryFormat() frequency_list = [] with tempfile.TemporaryDirectory() as tmp: - load_card_db(tmp, card_db, "load", False, False) + load_card_db(tmp=tmp, card_db=card_db) for samp_bin in list(manifest.index): bin_dir = os.path.join(str(amr_annotations), samp_bin[0], samp_bin[1]) os.makedirs(bin_dir, exist_ok=True) diff --git a/q2_amr/card/reads.py b/q2_amr/card/reads.py index e134818..724cbc0 100644 --- a/q2_amr/card/reads.py +++ b/q2_amr/card/reads.py @@ -43,8 +43,13 @@ def annotate_reads_card( amr_allele_annotation = CARDAlleleAnnotationDirectoryFormat() amr_gene_annotation = CARDGeneAnnotationDirectoryFormat() with tempfile.TemporaryDirectory() as tmp: - load_card_db(tmp, card_db, "load", False, False) - load_card_db(tmp, card_db, "load_fasta", include_other_models, include_wildcard) + load_card_db( + tmp=tmp, + card_db=card_db, + fasta=True, + include_other_models=include_other_models, + include_wildcard=include_wildcard, + ) for samp in list(manifest.index): fwd = manifest.loc[samp, "forward"] rev = manifest.loc[samp, "reverse"] if paired else None diff --git a/q2_amr/card/utils.py b/q2_amr/card/utils.py index 432e939..da21431 100644 --- a/q2_amr/card/utils.py +++ b/q2_amr/card/utils.py @@ -1,3 +1,4 @@ +import glob import json import os import subprocess @@ -22,42 +23,55 @@ def run_command(cmd, cwd, verbose=True): subprocess.run(cmd, check=True, cwd=cwd) -def load_card_db(tmp, card_db, operation, all_models, wildcard): +def load_card_db( + tmp, + card_db, + kmer_db=None, + kmer: bool = False, + fasta: bool = False, + include_other_models: bool = False, + include_wildcard: bool = False, +): path_card_json = os.path.join(str(card_db), "card.json") - if operation == "load": - cmd = ["rgi", "load", "--card_json", path_card_json, "--local"] - elif operation == "load_fasta": + cmd = ["rgi", "load", "--card_json", path_card_json, "--local"] + models = ("_all", "_all_models") if include_other_models is True else ("", "") + if fasta: with open(path_card_json) as f: card_data = json.load(f) version = card_data["_version"] - models = ("_all", "_all_models") if all_models is True else ("", "") path_card_fasta = os.path.join( str(card_db), f"card_database_v{version}{models[0]}.fasta" ) - cmd = [ - "rgi", - "load", - "-i", - path_card_json, - f"--card_annotation{models[1]}", - path_card_fasta, - "--local", - ] - if wildcard: - path_wildcard_fasta = os.path.join( - str(card_db), f"wildcard_database_v0{models[0]}.fasta" - ) - path_wildcard_index = os.path.join( - str(card_db), "index-for-model-sequences.txt" - ) - cmd.extend( - [ - f"--wildcard_annotation{models[1]}", - path_wildcard_fasta, - "--wildcard_index", - path_wildcard_index, - ] - ) + cmd.extend([f"--card_annotation{models[1]}", path_card_fasta]) + if include_wildcard: + path_wildcard_fasta = os.path.join( + str(card_db), f"wildcard_database_v0{models[0]}.fasta" + ) + path_wildcard_index = os.path.join( + str(card_db), "index-for-model-sequences.txt" + ) + cmd.extend( + [ + f"--wildcard_annotation{models[1]}", + path_wildcard_fasta, + "--wildcard_index", + path_wildcard_index, + ] + ) + if kmer: + path_kmer_json = glob.glob(os.path.join(str(kmer_db), "*_kmer_db.json"))[0] + path_kmer_txt = glob.glob(os.path.join(str(kmer_db), "all_amr_*mers.txt"))[0] + kmer_size = os.path.basename(path_kmer_json).split("_")[0] + cmd.extend( + [ + "--kmer_database", + path_kmer_json, + "--amr_kmers", + path_kmer_txt, + "--kmer_size", + kmer_size, + ] + ) try: run_command(cmd, tmp, verbose=True) diff --git a/q2_amr/tests/card/test_utils.py b/q2_amr/tests/card/test_utils.py index 6755324..da73615 100644 --- a/q2_amr/tests/card/test_utils.py +++ b/q2_amr/tests/card/test_utils.py @@ -8,7 +8,7 @@ from test_mags import TestAnnotateMagsCard from q2_amr.card.utils import create_count_table, load_card_db, read_in_txt -from q2_amr.types import CARDDatabaseDirectoryFormat, CARDDatabaseFormat +from q2_amr.types import CARDDatabaseDirectoryFormat, CARDKmerDatabaseDirectoryFormat class TestAnnotateReadsCARD(TestPluginBase): @@ -35,66 +35,55 @@ class TestAnnotateReadsCARD(TestPluginBase): } ) - def test_load_card_db(self): - card_db = CARDDatabaseFormat() - path_card_json = os.path.join(str(card_db), "card.json") - with patch("q2_amr.card.utils.run_command") as mock_run_command: - load_card_db("path_tmp", card_db, "load", False, False) - mock_run_command.assert_called_once_with( - ["rgi", "load", "--card_json", path_card_json, "--local"], - "path_tmp", - verbose=True, - ) - def test_load_card_db_fasta(self): card_db = CARDDatabaseDirectoryFormat() + kmer_db = CARDKmerDatabaseDirectoryFormat() card_json = self.get_data_path("card_test.json") shutil.copy(card_json, os.path.join(str(card_db), "card.json")) + kmer_txt = self.get_data_path("kmer_txt_test.txt") + shutil.copy(kmer_txt, os.path.join(str(kmer_db), "all_amr_61mers.txt")) + kmer_json = self.get_data_path("kmer_json_test.json") + shutil.copy(kmer_json, os.path.join(str(kmer_db), "61_kmer_db.json")) + with patch("q2_amr.card.utils.run_command") as mock_run_command: - load_card_db("path_tmp", card_db, "load_fasta", False, False) - load_card_db("path_tmp", card_db, "load_fasta", True, False) - load_card_db("path_tmp", card_db, "load_fasta", False, True) - load_card_db("path_tmp", card_db, "load_fasta", True, True) + load_card_db( + tmp="path_tmp", + card_db=card_db, + kmer_db=kmer_db, + kmer=True, + fasta=True, + include_other_models=False, + include_wildcard=True, + ) + load_card_db( + tmp="path_tmp", + card_db=card_db, + kmer_db=kmer_db, + kmer=True, + fasta=True, + include_other_models=True, + include_wildcard=True, + ) expected_calls = [ call( [ "rgi", "load", - "-i", - os.path.join(str(card_db), "card.json"), - "--card_annotation", - os.path.join(str(card_db), "card_database_v3.2.5.fasta"), - "--local", - ], - "path_tmp", - verbose=True, - ), - call( - [ - "rgi", - "load", - "-i", + "--card_json", os.path.join(str(card_db), "card.json"), - "--card_annotation_all_models", - os.path.join(str(card_db), "card_database_v3.2.5_all.fasta"), "--local", - ], - "path_tmp", - verbose=True, - ), - call( - [ - "rgi", - "load", - "-i", - os.path.join(str(card_db), "card.json"), "--card_annotation", os.path.join(str(card_db), "card_database_v3.2.5.fasta"), - "--local", "--wildcard_annotation", os.path.join(str(card_db), "wildcard_database_v0.fasta"), "--wildcard_index", os.path.join(str(card_db), "index-for-model-sequences.txt"), + "--kmer_database", + os.path.join(str(kmer_db), "61_kmer_db.json"), + "--amr_kmers", + os.path.join(str(kmer_db), "all_amr_61mers.txt"), + "--kmer_size", + "61", ], "path_tmp", verbose=True, @@ -103,15 +92,21 @@ def test_load_card_db_fasta(self): [ "rgi", "load", - "-i", + "--card_json", os.path.join(str(card_db), "card.json"), + "--local", "--card_annotation_all_models", os.path.join(str(card_db), "card_database_v3.2.5_all.fasta"), - "--local", "--wildcard_annotation_all_models", os.path.join(str(card_db), "wildcard_database_v0_all.fasta"), "--wildcard_index", os.path.join(str(card_db), "index-for-model-sequences.txt"), + "--kmer_database", + os.path.join(str(kmer_db), "61_kmer_db.json"), + "--amr_kmers", + os.path.join(str(kmer_db), "all_amr_61mers.txt"), + "--kmer_size", + "61", ], "path_tmp", verbose=True, @@ -127,12 +122,11 @@ def test_exception_raised(self): "An error was encountered while running rgi, " "(return code 1), please inspect stdout and stderr to learn more." ) - operation = "load" with patch( "q2_amr.card.utils.run_command" ) as mock_run_command, self.assertRaises(Exception) as cm: mock_run_command.side_effect = subprocess.CalledProcessError(1, "cmd") - load_card_db(tmp, card_db, operation) + load_card_db(tmp, card_db) self.assertEqual(str(cm.exception), expected_message) def test_read_in_txt_mags(self): diff --git a/q2_amr/tests/data/kmer_json_test.json b/q2_amr/tests/data/kmer_json_test.json new file mode 100644 index 0000000..4ae73fb --- /dev/null +++ b/q2_amr/tests/data/kmer_json_test.json @@ -0,0 +1 @@ +{"p": "dummy_value", "c": "dummy_value", "b": "dummy_value", "s": "dummy_value", "g": "dummy_value"} diff --git a/q2_amr/tests/data/kmer_txt_test.txt b/q2_amr/tests/data/kmer_txt_test.txt new file mode 100644 index 0000000..44c3277 --- /dev/null +++ b/q2_amr/tests/data/kmer_txt_test.txt @@ -0,0 +1,9 @@ +AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA 8 +TTTTCCCGCTGTGCGTGCTGCTGGCTTTCCTGGTGCTGGCCGCCCAATACGAAAGCTGGAG 1 +TACAGCAGCATCGAAGAAGCCTACATGGCGATCTTCCCACCGCCGCCGGTACAGGGCCTGG 3 +TCTCTCTATCAGATTGATCCTGCTACTTACCAGGCCTCTTATGAAAGTGCAAAAGGCGATC 1 +TAGTCGATCTGTCGTTGTTTACGTCGCGCAACTTCACCATCGGCTGCTTGTGTATCAGCCT 1 +AGCGAATTACTAGTTAAGGATAATGACCCTATCGCGCAAGATGTGTATGCCAAAGAAAAAG 2 +CCCGCGACACCACTCCCCCGGCCAGCATGGCCGCGACCCTGCGCAAGCTGCTGACCAGCCA 4 +CGGACGGATCAGCTTCGTCCTGATGGCAATGGCGGTCTTGTTTGCCGGTCTGATTGCCCGC 1 +ATCGAGGAAGTGGTGAAGACGCTGCTCGAGGGTATCGTCCTCGTGTTCCTCGTGATGTATC 252