Skip to content

Commit

Permalink
FIX: Fixed q2-amr failing, if no ARGs are detected (#9)
Browse files Browse the repository at this point in the history
-Fixed the problem of q2-amr failing if no ARGs are detected with the annotate_reads_card or annotate_mags_card methods. Now the user is informed that no ARGs were detected and no further output is produced.
-Removed mapq, mapped and coverage parameters from annotate_reads_card method.
-Added CARDGeneAnnotationDirectoryFormat/CARDGeneAnnotationDirectoryFormat -> qiime2.Metadata transformer.
-Fixed bug in heatmap method. Parameters clus and cat cannot be used at the same time.
  • Loading branch information
VinzentRisch authored Oct 16, 2023
1 parent 54d40c7 commit 6844065
Show file tree
Hide file tree
Showing 25 changed files with 608 additions and 280 deletions.
22 changes: 17 additions & 5 deletions q2_amr/assets/rgi/heatmap/index.html
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
{% extends 'tabbed.html' %}
{% extends 'base.html' %}

{% block tabcontent %}
{% block content %}

<div class="row">
<div class="col-lg-12">
<h3>Downloads</h3>
<h4>Downloads</h4>
</div>
</div>

Expand All @@ -16,15 +16,27 @@ <h3>Downloads</h3>
</div>
</div>
</div>
<br>

<br>

<div class="row">
<div class="col-lg-6">
<h3>My Picture</h3>
<h4>CARD annotation: heatmap</h4>
<img src="rgi_data/heatmap.png" alt="My Picture">
</div>
</div>

<br>

<div class="row mt-4">
<div class="col-lg-12">
<p>Displayed is a heatmap of annotate-mags-card output. <br>
Yellow represents a perfect hit, teal represents a <br>
strict hit, purple represents no hit.</p>
</div>
</div>


{% endblock %}

{% block footer %}
Expand Down
102 changes: 102 additions & 0 deletions q2_amr/card/heatmap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import glob
import os
import shutil
import subprocess
import tempfile
from distutils.dir_util import copy_tree

import pkg_resources
import q2templates

from q2_amr.card.utils import run_command
from q2_amr.types import CARDAnnotationDirectoryFormat


def heatmap(
    output_dir: str,
    amr_annotation: CARDAnnotationDirectoryFormat,
    clus: str = None,
    cat: str = None,
    display: str = "plain",
    frequency: bool = False,
):
    """Build the heatmap visualization for a CARD annotation directory.

    Every JSON file found two directory levels below the annotation
    directory is copied into a scratch folder under the name
    ``<sample>_<bin>``, ``rgi heatmap`` is run on that folder, the produced
    files are renamed to ``heatmap.<ext>`` and, together with the packaged
    template assets, copied into ``output_dir`` before ``index.html`` is
    rendered.

    Parameters:
    - output_dir (str): Directory the visualization is written to.
    - amr_annotation (CARDAnnotationDirectoryFormat): Annotation directory;
      ``str()`` of it is used as the filesystem path to search.
    - clus (str): Forwarded to ``run_rgi_heatmap``.
    - cat (str): Forwarded to ``run_rgi_heatmap``.
    - display (str): Forwarded to ``run_rgi_heatmap``.
    - frequency (bool): Forwarded to ``run_rgi_heatmap``.
    """
    assets_root = pkg_resources.resource_filename("q2_amr", "assets")
    template_dir = os.path.join(assets_root, "rgi", "heatmap")
    json_pattern = os.path.join(str(amr_annotation), "*", "*", "*.json")

    with tempfile.TemporaryDirectory() as tmp:
        results_dir = os.path.join(tmp, "results")
        json_files_dir = os.path.join(tmp, "json_files")
        for scratch_dir in (results_dir, json_files_dir):
            os.makedirs(scratch_dir)

        for source_path in glob.glob(json_pattern):
            # The last three path components are <sample>/<bin>/<file>.json.
            *_, sample, bin_name, _ = source_path.split(os.path.sep)
            shutil.copy(
                source_path, os.path.join(json_files_dir, f"{sample}_{bin_name}")
            )

        run_rgi_heatmap(tmp, json_files_dir, clus, cat, display, frequency)
        change_names(results_dir)
        # Results must be copied out before the temporary directory vanishes.
        copy_tree(template_dir, output_dir)
        copy_tree(results_dir, os.path.join(output_dir, "rgi_data"))

    q2templates.render(
        [os.path.join(template_dir, "index.html")],
        output_dir,
        context={"tabs": [{"title": "Heatmap", "url": "index.html"}]},
    )


class InvalidParameterCombinationError(Exception):
    """Raised when mutually exclusive parameters are supplied together.

    The text passed as ``message`` is available both through ``str()`` of the
    exception and through the ``message`` attribute.
    """

    def __init__(self, message="Invalid parameter combination"):
        super().__init__(message)
        self.message = message


def run_rgi_heatmap(tmp, json_files_dir, clus, cat, display, frequency):
    """Assemble and run the ``rgi heatmap`` command.

    Parameters:
    - tmp (str): Working directory; output is written to
      ``<tmp>/results/heatmap`` and ``tmp`` is passed to ``run_command``.
    - json_files_dir (str): Directory passed to ``--input``.
    - clus (str): Value for ``--clus``; the option is omitted when falsy.
    - cat (str): Value for ``--cat``; the option is omitted when falsy.
    - display (str): Value for ``--display``.
    - frequency (bool): When truthy, ``--frequency`` is appended.

    Raises:
    - InvalidParameterCombinationError: If ``clus`` is "genes" or "both"
      while ``cat`` is also set.
    - Exception: If the command fails with ``subprocess.CalledProcessError``;
      the original error is attached as the cause.
    """
    # Reject the invalid combination up front, before any command is built.
    if clus in ("both", "genes") and cat:
        raise InvalidParameterCombinationError(
            "If the parameter clus is set to genes "
            "or both it is not possible to use the "
            "cat parameter"
        )

    cmd = [
        "rgi",
        "heatmap",
        "--input",
        json_files_dir,
        "--output",
        f"{tmp}/results/heatmap",
        "--display",
        display,
    ]
    if clus:
        cmd.extend(["--clus", clus])
    if cat:
        cmd.extend(["--cat", cat])
    if frequency:
        cmd.append("--frequency")

    try:
        run_command(cmd, tmp, verbose=True)
    except subprocess.CalledProcessError as e:
        # Chain explicitly so the failing command stays visible as the cause.
        raise Exception(
            "An error was encountered while running rgi, "
            f"(return code {e.returncode}), please inspect "
            "stdout and stderr to learn more."
        ) from e


def change_names(results_dir):
    """
    Give the "rgi heatmap" output files fixed names.

    Any file in results_dir ending in .eps, .csv or .png is renamed to
    heatmap.<extension>, whatever its previous base name was, so that
    index.html can reference the files by a constant name. Files with other
    extensions are left untouched.
    Parameters:
    - results_dir (str): The directory where the files are stored.
    """
    renamable = (".eps", ".csv", ".png")
    for entry in os.listdir(results_dir):
        suffix = os.path.splitext(entry)[1]
        if suffix not in renamable:
            continue
        os.rename(
            os.path.join(results_dir, entry),
            os.path.join(results_dir, "heatmap" + suffix),
        )
39 changes: 27 additions & 12 deletions q2_amr/card/mags.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,28 @@
import pandas as pd
from q2_types_genomics.per_sample_data import MultiMAGSequencesDirFmt

from q2_amr.card.utils import load_preprocess_card_db, run_command
from q2_amr.card.utils import (
create_count_table,
load_preprocess_card_db,
read_in_txt,
run_command,
)
from q2_amr.types import CARDAnnotationDirectoryFormat, CARDDatabaseFormat


def annotate_mags_card(
mag: MultiMAGSequencesDirFmt,
card_db: CARDDatabaseFormat,
alignment_tool: str = "BLAST",
input_type: str = "contig",
split_prodigal_jobs: bool = False,
include_loose: bool = False,
include_nudge: bool = False,
low_quality: bool = False,
num_threads: int = 1,
) -> CARDAnnotationDirectoryFormat:
threads: int = 1,
) -> (CARDAnnotationDirectoryFormat, pd.DataFrame):
manifest = mag.manifest.view(pd.DataFrame)
amr_annotations = CARDAnnotationDirectoryFormat()
frequency_list = []
with tempfile.TemporaryDirectory() as tmp:
load_preprocess_card_db(tmp, card_db, "load")
for samp_bin in list(manifest.index):
Expand All @@ -33,24 +38,34 @@ def annotate_mags_card(
tmp,
input_sequence,
alignment_tool,
input_type,
split_prodigal_jobs,
include_loose,
include_nudge,
low_quality,
num_threads,
threads,
)
txt_path = os.path.join(bin_dir, "amr_annotation.txt")
json_path = os.path.join(bin_dir, "amr_annotation.json")

shutil.move(f"{tmp}/output.txt", txt_path)
shutil.move(f"{tmp}/output.json", json_path)
samp_bin_name = os.path.join(samp_bin[0], samp_bin[1])
frequency_df = read_in_txt(
path=txt_path, col_name="ARO", samp_bin_name=samp_bin_name
)
shutil.move(f"{tmp}/output.txt", f"{bin_dir}/amr_annotation.txt")
shutil.move(f"{tmp}/output.json", f"{bin_dir}/amr_annotation.json")
print("a")
return amr_annotations
if frequency_df is not None:
frequency_list.append(frequency_df)
feature_table = create_count_table(df_list=frequency_list)
return (
amr_annotations,
feature_table,
)


def run_rgi_main(
tmp,
input_sequence: str,
alignment_tool: str = "BLAST",
input_type: str = "contig",
split_prodigal_jobs: bool = False,
include_loose: bool = False,
include_nudge: bool = False,
Expand All @@ -69,7 +84,7 @@ def run_rgi_main(
"--alignment_tool",
alignment_tool,
"--input_type",
input_type,
"contig",
"--local",
]
if include_loose:
Expand Down
70 changes: 18 additions & 52 deletions q2_amr/card/reads.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import subprocess
import tempfile
from distutils.dir_util import copy_tree
from functools import reduce
from typing import Union

import altair as alt
Expand All @@ -15,7 +14,12 @@
SingleLanePerSampleSingleEndFastqDirFmt,
)

from q2_amr.card.utils import load_preprocess_card_db, run_command
from q2_amr.card.utils import (
create_count_table,
load_preprocess_card_db,
read_in_txt,
run_command,
)
from q2_amr.types import (
CARDAlleleAnnotationDirectoryFormat,
CARDDatabaseFormat,
Expand All @@ -30,10 +34,6 @@ def annotate_reads_card(
card_db: CARDDatabaseFormat,
aligner: str = "kma",
threads: int = 1,
include_baits: bool = False,
mapq: float = None,
mapped: float = None,
coverage: float = None,
) -> (
CARDAlleleAnnotationDirectoryFormat,
CARDGeneAnnotationDirectoryFormat,
Expand Down Expand Up @@ -65,15 +65,19 @@ def annotate_reads_card(
rev=rev,
aligner=aligner,
threads=threads,
include_baits=include_baits,
mapq=mapq,
mapped=mapped,
coverage=coverage,
)
allele_frequency = read_in_txt(samp_input_dir, "allele")
allele_frequency_list.append(allele_frequency)
gene_frequency = read_in_txt(samp_input_dir, "gene")
gene_frequency_list.append(gene_frequency)
path_allele = os.path.join(samp_input_dir, "output.allele_mapping_data.txt")
allele_frequency = read_in_txt(
path=path_allele, col_name="ARO Accession", samp_bin_name=samp
)
if allele_frequency is not None:
allele_frequency_list.append(allele_frequency)
path_gene = os.path.join(samp_input_dir, "output.gene_mapping_data.txt")
gene_frequency = read_in_txt(
path=path_gene, col_name="ARO Accession", samp_bin_name=samp
)
if gene_frequency is not None:
gene_frequency_list.append(gene_frequency)
move_files(samp_input_dir, samp_allele_dir, "allele")
move_files(samp_input_dir, samp_gene_dir, "gene")

Expand All @@ -98,43 +102,13 @@ def move_files(source_dir: str, des_dir: str, map_type: str):
)


def create_count_table(df_list: list) -> pd.DataFrame:
    """Combine per-sample count frames into a single count table.

    Each frame in ``df_list`` must contain an "ARO Accession" column; the
    frames are outer-merged on it, missing counts are filled with 0, and the
    result is transposed so that the remaining frame columns become rows
    (index named "sample_id") and the ARO accessions become column labels.

    Parameters:
    - df_list (list): DataFrames to merge.

    Returns:
    - pd.DataFrame: The transposed count table.

    Raises:
    - ValueError: If ``df_list`` is empty, since there is nothing to merge.
    """
    # reduce() without an initial value fails with an opaque TypeError on an
    # empty sequence; report the actual problem instead.
    if not df_list:
        raise ValueError(
            "RGI did not identify any AMR genes. No output can be created."
        )
    df_merged = reduce(
        lambda left, right: pd.merge(left, right, on="ARO Accession", how="outer"),
        df_list,
    )
    df_transposed = df_merged.transpose()
    df_transposed = df_transposed.fillna(0)
    # After transposing, the first row holds the accessions: use it as header.
    df_transposed.columns = df_transposed.iloc[0]
    df_transposed = df_transposed.drop("ARO Accession")
    df_transposed.columns.name = None
    df_transposed.index.name = "sample_id"
    return df_transposed


def read_in_txt(samp_dir: str, map_type: str):
    """Count the occurrences of each ARO accession in a mapping data file.

    Reads ``output.<map_type>_mapping_data.txt`` (tab-separated) from
    ``samp_dir`` and returns a frame with one row per distinct
    "ARO Accession" (as strings) and a second column, named after the last
    path component of ``samp_dir``, holding how often that accession occurs.

    Parameters:
    - samp_dir (str): Directory containing the mapping data file.
    - map_type (str): Inserted into the file name.
    """
    mapping_file = os.path.join(samp_dir, f"output.{map_type}_mapping_data.txt")
    accession_col = "ARO Accession"
    counts = pd.read_csv(mapping_file, sep="\t")[[accession_col]].astype(str)
    sample_name = os.path.basename(samp_dir)
    counts[sample_name] = counts.groupby(accession_col)[accession_col].transform(
        "count"
    )
    return counts.drop_duplicates(subset=[accession_col])


def run_rgi_bwt(
cwd: str,
samp: str,
fwd: str,
rev: str,
aligner: str,
threads: int,
include_baits: bool,
mapq: float,
mapped: float,
coverage: float,
):
cmd = [
"rgi",
Expand All @@ -152,14 +126,6 @@ def run_rgi_bwt(
]
if rev:
cmd.extend(["--read_two", rev])
if include_baits:
cmd.append("--include_baits")
if mapq:
cmd.extend(["--mapq", str(mapq)])
if mapped:
cmd.extend(["--mapped", str(mapped)])
if coverage:
cmd.extend(["--coverage", str(coverage)])
try:
run_command(cmd, cwd, verbose=True)
except subprocess.CalledProcessError as e:
Expand Down
32 changes: 32 additions & 0 deletions q2_amr/card/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import json
import subprocess
from functools import reduce

import pandas as pd

EXTERNAL_CMD_WARNING = (
"Running external command line application(s). "
Expand Down Expand Up @@ -44,3 +47,32 @@ def load_preprocess_card_db(tmp, card_db, operation):
f"(return code {e.returncode}), please inspect "
"stdout and stderr to learn more."
)


def read_in_txt(path: str, col_name: str, samp_bin_name: str):
    """Count how often each value of one column occurs in a tab-separated file.

    Parameters:
    - path (str): Path of the tab-separated file to read.
    - col_name (str): Column whose values (converted to strings) are counted.
    - samp_bin_name (str): Name given to the column holding the counts.

    Returns:
    - None if the file contains no data rows, otherwise a DataFrame with one
      row per distinct ``col_name`` value and its count in ``samp_bin_name``.
    """
    table = pd.read_csv(path, sep="\t")
    if table.empty:
        # Nothing to count; callers check for None.
        return None
    counts = table[[col_name]].astype(str)
    per_value = counts.groupby(col_name)[col_name]
    counts[samp_bin_name] = per_value.transform("count")
    return counts.drop_duplicates(subset=[col_name])


def create_count_table(df_list: list) -> pd.DataFrame:
    """Merge per-sample count frames into one transposed count table.

    The frames are outer-merged one after another on the first column of the
    running result, missing counts are replaced by 0, and the merged frame is
    transposed: its first row (the key values) becomes the column labels and
    the remaining rows are indexed under the name "sample_id".

    Parameters:
    - df_list (list): DataFrames to merge.

    Raises:
    - ValueError: If ``df_list`` is empty.
    """
    if not df_list:
        raise ValueError(
            "RGI did not identify any AMR genes. No output can be created."
        )
    frames = iter(df_list)
    merged = next(frames)
    for frame in frames:
        merged = pd.merge(merged, frame, on=merged.columns[0], how="outer")

    table = merged.transpose().fillna(0)
    # Promote the key row to the header, then remove it from the data rows.
    table.columns = table.iloc[0]
    table = table.drop(table.index[0])
    table.columns.name = None
    table.index.name = "sample_id"
    return table
Loading

0 comments on commit 6844065

Please sign in to comment.