Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: parallelize Kaiju classification #200

Merged
merged 11 commits into from
Oct 3, 2024
4 changes: 2 additions & 2 deletions q2_moshpit/kaiju/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
# ----------------------------------------------------------------------------

from .database import fetch_kaiju_db
from .classification import classify_kaiju
from .classification import classify_kaiju, _classify_kaiju

__all__ = ["fetch_kaiju_db", "classify_kaiju"]
__all__ = ["fetch_kaiju_db", "classify_kaiju", "_classify_kaiju"]
55 changes: 52 additions & 3 deletions q2_moshpit/kaiju/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
from q2_types.per_sample_sequences import (
SingleLanePerSamplePairedEndFastqDirFmt,
SingleLanePerSampleSingleEndFastqDirFmt,
SequencesWithQuality,
PairedEndSequencesWithQuality,
)

from q2_moshpit._utils import run_command
from q2_types.kaiju import KaijuDBDirectoryFormat
from q2_types.sample_data import SampleData

DEFAULT_PREFIXES = ["d__", "p__", "c__", "o__", "f__", "g__", "s__", "ssp__"]

Expand Down Expand Up @@ -195,7 +198,7 @@ def _process_kaiju_reports(tmpdir, all_args):
return _construct_feature_table(table_fp)


def _classify_kaiju(
def _classify_kaiju_helper(
manifest: pd.DataFrame, all_args: dict
) -> (pd.DataFrame, pd.DataFrame):
"""
Expand Down Expand Up @@ -255,7 +258,7 @@ def _classify_kaiju(
return table, taxonomy


def classify_kaiju(
def _classify_kaiju(
seqs: Union[
SingleLanePerSamplePairedEndFastqDirFmt,
SingleLanePerSampleSingleEndFastqDirFmt,
Expand All @@ -274,4 +277,50 @@ def classify_kaiju(
u: bool = False,
) -> (pd.DataFrame, pd.DataFrame):
manifest: pd.DataFrame = seqs.manifest.view(pd.DataFrame)
return _classify_kaiju(manifest, dict(locals().items()))
return _classify_kaiju_helper(manifest, dict(locals().items()))


def classify_kaiju(
    ctx,
    seqs,
    db,
    z=1,
    a="greedy",
    e=3,
    m=11,
    s=65,
    evalue=0.01,
    x=True,
    r="species",
    c=0.0,
    exp=False,
    u=False,
    num_partitions=None
):
    """Pipeline: classify reads with Kaiju partition by partition.

    The input sequences are split with the matching ``demux`` partition
    action, every partition is classified by the ``_classify_kaiju``
    action, and the per-partition results are merged with
    ``feature_table.merge`` and ``feature_table.merge_taxa``.

    Parameters
    ----------
    ctx
        Pipeline context; only ``ctx.get_action(plugin, action)`` is used.
    seqs
        Sequence artifact whose ``type`` is a subtype of either
        ``SampleData[SequencesWithQuality]`` or
        ``SampleData[PairedEndSequencesWithQuality]``.
    db
        Kaiju database, handed unchanged to ``_classify_kaiju``.
    z, a, e, m, s, evalue, x, r, c, exp, u
        Classification parameters, forwarded unchanged as keyword
        arguments to every ``_classify_kaiju`` call.
    num_partitions
        Passed as-is to the partition action.
        NOTE(review): ``None`` presumably lets the partition action pick
        its own default - confirm against q2-demux.

    Returns
    -------
    tuple
        ``(combined_table, collated_taxonomy)`` - the merged feature table
        and the merged taxonomy.

    Raises
    ------
    NotImplementedError
        If ``seqs.type`` is neither single-end nor paired-end sequences
        with quality.
    """
    # Spell the forwarded parameters out instead of filtering locals(), so
    # that a local added to this function later cannot leak into the call.
    kwargs = {
        "z": z, "a": a, "e": e, "m": m, "s": s, "evalue": evalue,
        "x": x, "r": r, "c": c, "exp": exp, "u": u,
    }

    # Named differently from the module-level `_classify_kaiju` function to
    # avoid shadowing it inside this scope.
    classify_partition = ctx.get_action("moshpit", "_classify_kaiju")
    collate_feature_tables = ctx.get_action("feature_table", "merge")
    collate_taxonomies = ctx.get_action("feature_table", "merge_taxa")

    if seqs.type <= SampleData[SequencesWithQuality]:
        partition_method = ctx.get_action("demux", "partition_samples_single")
    elif seqs.type <= SampleData[PairedEndSequencesWithQuality]:
        partition_method = ctx.get_action("demux", "partition_samples_paired")
    else:
        raise NotImplementedError(
            f"Unsupported input type for Kaiju classification: "
            f"{seqs.type}. Expected SampleData[SequencesWithQuality] or "
            f"SampleData[PairedEndSequencesWithQuality]."
        )

    (partitioned_seqs,) = partition_method(seqs, num_partitions)

    tables = []
    taxonomies = []
    for seq in partitioned_seqs.values():
        (table, taxonomy) = classify_partition(seq, db, **kwargs)
        tables.append(table)
        taxonomies.append(taxonomy)

    (combined_table,) = collate_feature_tables(tables)
    (collated_taxonomy,) = collate_taxonomies(taxonomies)

    return combined_table, collated_taxonomy
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
181 changes: 175 additions & 6 deletions q2_moshpit/kaiju/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,53 @@
import os
import unittest
from subprocess import CalledProcessError
from unittest.mock import patch, Mock, ANY
from unittest.mock import patch, Mock, ANY, MagicMock, call

import numpy as np
import pandas as pd
import qiime2
from pandas._testing import assert_frame_equal
from q2_types.per_sample_sequences import (
SingleLanePerSamplePairedEndFastqDirFmt,
SingleLanePerSampleSingleEndFastqDirFmt
SingleLanePerSampleSingleEndFastqDirFmt,
CasavaOneEightSingleLanePerSampleDirFmt
)

from q2_moshpit.kaiju.classification import (
_construct_feature_table, _rename_taxon, _encode_unclassified_ids,
_fix_id_types, _process_kaiju_reports, classify_kaiju
_fix_id_types, _process_kaiju_reports, _classify_kaiju, classify_kaiju
)
from qiime2.plugin.testing import TestPluginBase
from qiime2.plugins import moshpit


class TestKaijuClassification(TestPluginBase):
package = 'q2_moshpit.kaiju.tests'

def setUp(self):
super().setUp()
self.ctx = MagicMock()

self.mock_classify_kaiju = MagicMock(
side_effect=[("table1", "taxonomy1"), ("table2", "taxonomy2")]
)
self.ctx.get_action.side_effect = lambda domain, action_name: {
('moshpit', '_classify_kaiju'): self.mock_classify_kaiju,
('feature_table', 'merge'): self.mock_merge,
('feature_table', 'merge_taxa'): self.mock_merge_taxa,
('demux', 'partition_samples_single'): self.mock_partition_single,
('demux', 'partition_samples_paired'): self.mock_partition_paired,
}[domain, action_name]

# Additional action mocks
self.mock_merge = MagicMock(return_value=("merged_table",))
self.mock_merge_taxa = MagicMock(return_value=("merged_taxa",))
self.mock_partition_single = MagicMock()
self.mock_partition_paired = MagicMock()

self.mock_partition_single.return_value = {'part1': Mock()}
self.mock_partition_paired.return_value = {'part1': Mock()}
self.classify_kaiju = moshpit.pipelines.classify_kaiju
with open(self.get_data_path('taxa-map.json')) as f:
self.taxa_map = json.load(f)

Expand Down Expand Up @@ -186,7 +211,7 @@ def test_classify_kaiju_single(self, p1, p2):
p1.return_value = [pd.DataFrame(), pd.DataFrame()]

with patch("tempfile.TemporaryDirectory"):
classify_kaiju(
_classify_kaiju(
seqs=seqs, db=Mock(path=self.temp_dir.name),
z=3, a="greedy", e=2, m=10, s=66, evalue=0, x=False,
r="class", c=0.1, exp=True, u=True
Expand Down Expand Up @@ -220,7 +245,7 @@ def test_classify_kaiju_paired(self, p1, p2):
p1.return_value = [pd.DataFrame(), pd.DataFrame()]

with patch("tempfile.TemporaryDirectory"):
classify_kaiju(
_classify_kaiju(
seqs=seqs, db=Mock(path=self.temp_dir.name),
z=3, a="greedy", e=2, m=10, s=66, evalue=0, x=True,
r="class", c=0.1, exp=False, u=False
Expand Down Expand Up @@ -256,12 +281,156 @@ def test_classify_kaiju_exception(self, p1):
with self.assertRaisesRegex(
Exception, r"\(return code 1\), please inspect"
):
classify_kaiju(
_classify_kaiju(
seqs=seqs, db=Mock(path=self.temp_dir.name),
z=3, a="greedy", e=2, m=10, s=66, evalue=0, x=False,
r="class", c=0.1, exp=True, u=True
)

def test_classify_kaiju_single_partition_single_end(self):
    """Single-end input with one partition: each action is used once."""
    db = Mock()
    seqs = qiime2.Artifact.import_data(
        "SampleData[SequencesWithQuality]",
        self.get_data_path("single-end-casava"),
        CasavaOneEightSingleLanePerSampleDirFmt
    )
    # Pipeline defaults that must reach the classification action.
    expected_params = dict(
        z=1, a='greedy', e=3, m=11, s=65, evalue=0.01,
        x=True, r='species', c=0.0, exp=False, u=False
    )

    self.mock_partition_single.side_effect = [({0: "part1"},)]

    obs_table, obs_taxonomy = classify_kaiju(
        self.ctx, seqs, db, num_partitions=1
    )

    for plugin, action in (
        ("moshpit", "_classify_kaiju"),
        ("feature_table", "merge"),
        ("feature_table", "merge_taxa"),
        ("demux", "partition_samples_single"),
    ):
        self.ctx.get_action.assert_any_call(plugin, action)

    self.mock_partition_single.assert_called_once_with(seqs, 1)
    self.mock_partition_paired.assert_not_called()
    self.mock_classify_kaiju.assert_called_once_with(
        "part1", db, **expected_params
    )
    self.mock_merge.assert_called_once_with(["table1"])
    self.mock_merge_taxa.assert_called_once_with(["taxonomy1"])
    self.assertEqual("merged_table", obs_table)
    self.assertEqual("merged_taxa", obs_taxonomy)

def test_classify_kaiju_single_partition_paired_end(self):
    """Paired-end input with one partition: each action is used once."""
    db = Mock()
    seqs = qiime2.Artifact.import_data(
        "SampleData[PairedEndSequencesWithQuality]",
        self.get_data_path("paired-end-casava"),
        CasavaOneEightSingleLanePerSampleDirFmt
    )
    # Pipeline defaults that must reach the classification action.
    expected_params = dict(
        z=1, a='greedy', e=3, m=11, s=65, evalue=0.01,
        x=True, r='species', c=0.0, exp=False, u=False
    )

    self.mock_partition_paired.side_effect = [({0: "part1"},)]

    obs_table, obs_taxonomy = classify_kaiju(
        self.ctx, seqs, db, num_partitions=1
    )

    for plugin, action in (
        ("moshpit", "_classify_kaiju"),
        ("feature_table", "merge"),
        ("feature_table", "merge_taxa"),
        ("demux", "partition_samples_paired"),
    ):
        self.ctx.get_action.assert_any_call(plugin, action)

    self.mock_partition_single.assert_not_called()
    self.mock_partition_paired.assert_called_once_with(seqs, 1)
    self.mock_classify_kaiju.assert_called_once_with(
        "part1", db, **expected_params
    )
    self.mock_merge.assert_called_once_with(["table1"])
    self.mock_merge_taxa.assert_called_once_with(["taxonomy1"])
    self.assertEqual("merged_table", obs_table)
    self.assertEqual("merged_taxa", obs_taxonomy)

def test_classify_kaiju_multiple_partitions_single_end(self):
    """Single-end input, two partitions: classified in order, then merged."""
    db = Mock()
    seqs = qiime2.Artifact.import_data(
        "SampleData[SequencesWithQuality]",
        self.get_data_path("single-end-casava"),
        CasavaOneEightSingleLanePerSampleDirFmt
    )
    # Pipeline defaults that must reach every classification call.
    expected_params = dict(
        z=1, a='greedy', e=3, m=11, s=65, evalue=0.01,
        x=True, r='species', c=0.0, exp=False, u=False
    )

    self.mock_partition_single.side_effect = [({0: "part1", 1: "part2"},)]

    obs_table, obs_taxonomy = classify_kaiju(
        self.ctx, seqs, db, num_partitions=2
    )

    for plugin, action in (
        ("moshpit", "_classify_kaiju"),
        ("feature_table", "merge"),
        ("feature_table", "merge_taxa"),
        ("demux", "partition_samples_single"),
    ):
        self.ctx.get_action.assert_any_call(plugin, action)

    self.mock_partition_single.assert_called_once_with(seqs, 2)
    self.mock_partition_paired.assert_not_called()

    expected_calls = [
        call(partition, db, **expected_params)
        for partition in ("part1", "part2")
    ]
    self.mock_classify_kaiju.assert_has_calls(expected_calls)

    self.mock_merge.assert_called_once_with(["table1", "table2"])
    self.mock_merge_taxa.assert_called_once_with(
        ["taxonomy1", "taxonomy2"]
    )
    self.assertEqual("merged_table", obs_table)
    self.assertEqual("merged_taxa", obs_taxonomy)

def test_classify_kaiju_multiple_partitions_paired_end(self):
    """Paired-end input, two partitions: classified in order, then merged."""
    db = Mock()
    seqs = qiime2.Artifact.import_data(
        "SampleData[PairedEndSequencesWithQuality]",
        self.get_data_path("paired-end-casava"),
        CasavaOneEightSingleLanePerSampleDirFmt
    )
    # Pipeline defaults that must reach every classification call.
    expected_params = dict(
        z=1, a='greedy', e=3, m=11, s=65, evalue=0.01,
        x=True, r='species', c=0.0, exp=False, u=False
    )

    self.mock_partition_paired.side_effect = [({0: "part1", 1: "part2"},)]

    obs_table, obs_taxonomy = classify_kaiju(
        self.ctx, seqs, db, num_partitions=2
    )

    for plugin, action in (
        ("moshpit", "_classify_kaiju"),
        ("feature_table", "merge"),
        ("feature_table", "merge_taxa"),
        ("demux", "partition_samples_paired"),
    ):
        self.ctx.get_action.assert_any_call(plugin, action)

    self.mock_partition_single.assert_not_called()
    self.mock_partition_paired.assert_called_once_with(seqs, 2)

    expected_calls = [
        call(partition, db, **expected_params)
        for partition in ("part1", "part2")
    ]
    self.mock_classify_kaiju.assert_has_calls(expected_calls)

    self.mock_merge.assert_called_once_with(["table1", "table2"])
    self.mock_merge_taxa.assert_called_once_with(
        ["taxonomy1", "taxonomy2"]
    )
    self.assertEqual("merged_table", obs_table)
    self.assertEqual("merged_taxa", obs_taxonomy)


if __name__ == "__main__":
unittest.main()
Loading
Loading