Skip to content

Commit

Permalink
ENH: parallelize Kaiju classification (bokulich-lab#200)
Browse files Browse the repository at this point in the history
Co-authored-by: Christos Konstantinos Matzoros <[email protected]>
  • Loading branch information
misialq and ChristosMatzoros authored Oct 3, 2024
1 parent a6e3a4d commit 428fcae
Show file tree
Hide file tree
Showing 10 changed files with 292 additions and 41 deletions.
4 changes: 2 additions & 2 deletions q2_moshpit/kaiju/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@
# ----------------------------------------------------------------------------

from .database import fetch_kaiju_db
from .classification import classify_kaiju
from .classification import classify_kaiju, _classify_kaiju

__all__ = ["fetch_kaiju_db", "classify_kaiju"]
__all__ = ["fetch_kaiju_db", "classify_kaiju", "_classify_kaiju"]
55 changes: 52 additions & 3 deletions q2_moshpit/kaiju/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
from q2_types.per_sample_sequences import (
SingleLanePerSamplePairedEndFastqDirFmt,
SingleLanePerSampleSingleEndFastqDirFmt,
SequencesWithQuality,
PairedEndSequencesWithQuality,
)

from q2_moshpit._utils import run_command
from q2_types.kaiju import KaijuDBDirectoryFormat
from q2_types.sample_data import SampleData

DEFAULT_PREFIXES = ["d__", "p__", "c__", "o__", "f__", "g__", "s__", "ssp__"]

Expand Down Expand Up @@ -195,7 +198,7 @@ def _process_kaiju_reports(tmpdir, all_args):
return _construct_feature_table(table_fp)


def _classify_kaiju(
def _classify_kaiju_helper(
manifest: pd.DataFrame, all_args: dict
) -> (pd.DataFrame, pd.DataFrame):
"""
Expand Down Expand Up @@ -255,7 +258,7 @@ def _classify_kaiju(
return table, taxonomy


def classify_kaiju(
def _classify_kaiju(
seqs: Union[
SingleLanePerSamplePairedEndFastqDirFmt,
SingleLanePerSampleSingleEndFastqDirFmt,
Expand All @@ -274,4 +277,50 @@ def classify_kaiju(
u: bool = False,
) -> (pd.DataFrame, pd.DataFrame):
manifest: pd.DataFrame = seqs.manifest.view(pd.DataFrame)
return _classify_kaiju(manifest, dict(locals().items()))
return _classify_kaiju_helper(manifest, dict(locals().items()))


def classify_kaiju(
ctx,
seqs,
db,
z=1,
a="greedy",
e=3,
m=11,
s=65,
evalue=0.01,
x=True,
r="species",
c=0.0,
exp=False,
u=False,
num_partitions=None
):
kwargs = {k: v for k, v in locals().items()
if k not in ["seqs", "db", "ctx", "num_partitions"]}

_classify_kaiju = ctx.get_action("moshpit", "_classify_kaiju")
collate_feature_tables = ctx.get_action("feature_table", "merge")
collate_taxonomies = ctx.get_action("feature_table", "merge_taxa")

if seqs.type <= SampleData[SequencesWithQuality]:
partition_method = ctx.get_action("demux", "partition_samples_single")
elif seqs.type <= SampleData[PairedEndSequencesWithQuality]:
partition_method = ctx.get_action("demux", "partition_samples_paired")
else:
raise NotImplementedError()

(partitioned_seqs,) = partition_method(seqs, num_partitions)

tables = []
taxonomies = []
for seq in partitioned_seqs.values():
(table, taxonomy) = _classify_kaiju(seq, db, **kwargs)
tables.append(table)
taxonomies.append(taxonomy)

(combined_table,) = collate_feature_tables(tables)
(collated_taxonomy,) = collate_taxonomies(taxonomies)

return combined_table, collated_taxonomy
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
181 changes: 175 additions & 6 deletions q2_moshpit/kaiju/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,53 @@
import os
import unittest
from subprocess import CalledProcessError
from unittest.mock import patch, Mock, ANY
from unittest.mock import patch, Mock, ANY, MagicMock, call

import numpy as np
import pandas as pd
import qiime2
from pandas._testing import assert_frame_equal
from q2_types.per_sample_sequences import (
SingleLanePerSamplePairedEndFastqDirFmt,
SingleLanePerSampleSingleEndFastqDirFmt
SingleLanePerSampleSingleEndFastqDirFmt,
CasavaOneEightSingleLanePerSampleDirFmt
)

from q2_moshpit.kaiju.classification import (
_construct_feature_table, _rename_taxon, _encode_unclassified_ids,
_fix_id_types, _process_kaiju_reports, classify_kaiju
_fix_id_types, _process_kaiju_reports, _classify_kaiju, classify_kaiju
)
from qiime2.plugin.testing import TestPluginBase
from qiime2.plugins import moshpit


class TestKaijuClassification(TestPluginBase):
package = 'q2_moshpit.kaiju.tests'

def setUp(self):
super().setUp()
self.ctx = MagicMock()

self.mock_classify_kaiju = MagicMock(
side_effect=[("table1", "taxonomy1"), ("table2", "taxonomy2")]
)
self.ctx.get_action.side_effect = lambda domain, action_name: {
('moshpit', '_classify_kaiju'): self.mock_classify_kaiju,
('feature_table', 'merge'): self.mock_merge,
('feature_table', 'merge_taxa'): self.mock_merge_taxa,
('demux', 'partition_samples_single'): self.mock_partition_single,
('demux', 'partition_samples_paired'): self.mock_partition_paired,
}[domain, action_name]

# Additional action mocks
self.mock_merge = MagicMock(return_value=("merged_table",))
self.mock_merge_taxa = MagicMock(return_value=("merged_taxa",))
self.mock_partition_single = MagicMock()
self.mock_partition_paired = MagicMock()

self.mock_partition_single.return_value = {'part1': Mock()}
self.mock_partition_paired.return_value = {'part1': Mock()}
self.classify_kaiju = moshpit.pipelines.classify_kaiju
with open(self.get_data_path('taxa-map.json')) as f:
self.taxa_map = json.load(f)

Expand Down Expand Up @@ -186,7 +211,7 @@ def test_classify_kaiju_single(self, p1, p2):
p1.return_value = [pd.DataFrame(), pd.DataFrame()]

with patch("tempfile.TemporaryDirectory"):
classify_kaiju(
_classify_kaiju(
seqs=seqs, db=Mock(path=self.temp_dir.name),
z=3, a="greedy", e=2, m=10, s=66, evalue=0, x=False,
r="class", c=0.1, exp=True, u=True
Expand Down Expand Up @@ -220,7 +245,7 @@ def test_classify_kaiju_paired(self, p1, p2):
p1.return_value = [pd.DataFrame(), pd.DataFrame()]

with patch("tempfile.TemporaryDirectory"):
classify_kaiju(
_classify_kaiju(
seqs=seqs, db=Mock(path=self.temp_dir.name),
z=3, a="greedy", e=2, m=10, s=66, evalue=0, x=True,
r="class", c=0.1, exp=False, u=False
Expand Down Expand Up @@ -256,12 +281,156 @@ def test_classify_kaiju_exception(self, p1):
with self.assertRaisesRegex(
Exception, r"\(return code 1\), please inspect"
):
classify_kaiju(
_classify_kaiju(
seqs=seqs, db=Mock(path=self.temp_dir.name),
z=3, a="greedy", e=2, m=10, s=66, evalue=0, x=False,
r="class", c=0.1, exp=True, u=True
)

def test_classify_kaiju_single_partition_single_end(self):
fake_seqs = qiime2.Artifact.import_data(
"SampleData[SequencesWithQuality]",
self.get_data_path("single-end-casava"),
CasavaOneEightSingleLanePerSampleDirFmt
)
fake_db = Mock()

self.mock_partition_single.side_effect = [({0: "part1"},)]

out_table, out_taxonomy = classify_kaiju(
self.ctx, fake_seqs, fake_db, num_partitions=1
)

self.ctx.get_action.assert_any_call("moshpit", "_classify_kaiju")
self.ctx.get_action.assert_any_call("feature_table", "merge")
self.ctx.get_action.assert_any_call("feature_table", "merge_taxa")
self.ctx.get_action.assert_any_call(
"demux", "partition_samples_single"
)

self.mock_partition_single.assert_called_once_with(fake_seqs, 1)
self.mock_partition_paired.assert_not_called()
self.mock_classify_kaiju.assert_called_once_with(
"part1", fake_db, z=1, a='greedy', e=3, m=11, s=65, evalue=0.01,
x=True, r='species', c=0.0, exp=False, u=False
)
self.mock_merge.assert_called_once_with(["table1"])
self.mock_merge_taxa.assert_called_once_with(["taxonomy1"])
self.assertEqual("merged_table", out_table)
self.assertEqual("merged_taxa", out_taxonomy)

def test_classify_kaiju_single_partition_paired_end(self):
fake_seqs = qiime2.Artifact.import_data(
"SampleData[PairedEndSequencesWithQuality]",
self.get_data_path("paired-end-casava"),
CasavaOneEightSingleLanePerSampleDirFmt
)
fake_db = Mock()

self.mock_partition_paired.side_effect = [({0: "part1"},)]

out_table, out_taxonomy = classify_kaiju(
self.ctx, fake_seqs, fake_db, num_partitions=1
)

self.ctx.get_action.assert_any_call("moshpit", "_classify_kaiju")
self.ctx.get_action.assert_any_call("feature_table", "merge")
self.ctx.get_action.assert_any_call("feature_table", "merge_taxa")
self.ctx.get_action.assert_any_call(
"demux", "partition_samples_paired"
)

self.mock_partition_single.assert_not_called()
self.mock_partition_paired.assert_called_once_with(fake_seqs, 1)
self.mock_classify_kaiju.assert_called_once_with(
"part1", fake_db, z=1, a='greedy', e=3, m=11, s=65, evalue=0.01,
x=True, r='species', c=0.0, exp=False, u=False
)
self.mock_merge.assert_called_once_with(["table1"])
self.mock_merge_taxa.assert_called_once_with(["taxonomy1"])
self.assertEqual("merged_table", out_table)
self.assertEqual("merged_taxa", out_taxonomy)

def test_classify_kaiju_multiple_partitions_single_end(self):
fake_seqs = qiime2.Artifact.import_data(
"SampleData[SequencesWithQuality]",
self.get_data_path("single-end-casava"),
CasavaOneEightSingleLanePerSampleDirFmt
)
fake_db = Mock()

self.mock_partition_single.side_effect = [({0: "part1", 1: "part2"},)]

out_table, out_taxonomy = classify_kaiju(
self.ctx, fake_seqs, fake_db, num_partitions=2
)

self.ctx.get_action.assert_any_call("moshpit", "_classify_kaiju")
self.ctx.get_action.assert_any_call("feature_table", "merge")
self.ctx.get_action.assert_any_call("feature_table", "merge_taxa")
self.ctx.get_action.assert_any_call(
"demux", "partition_samples_single"
)

self.mock_partition_single.assert_called_once_with(fake_seqs, 2)
self.mock_partition_paired.assert_not_called()
self.mock_classify_kaiju.assert_has_calls([
call(
"part1", fake_db, z=1, a='greedy', e=3, m=11, s=65,
evalue=0.01, x=True, r='species', c=0.0, exp=False, u=False
),
call(
"part2", fake_db, z=1, a='greedy', e=3, m=11, s=65,
evalue=0.01, x=True, r='species', c=0.0, exp=False, u=False
)
])
self.mock_merge.assert_called_once_with(["table1", "table2"])
self.mock_merge_taxa.assert_called_once_with(
["taxonomy1", "taxonomy2"]
)
self.assertEqual("merged_table", out_table)
self.assertEqual("merged_taxa", out_taxonomy)

def test_classify_kaiju_multiple_partitions_paired_end(self):
fake_seqs = qiime2.Artifact.import_data(
"SampleData[PairedEndSequencesWithQuality]",
self.get_data_path("paired-end-casava"),
CasavaOneEightSingleLanePerSampleDirFmt
)
fake_db = Mock()

self.mock_partition_paired.side_effect = [({0: "part1", 1: "part2"},)]

out_table, out_taxonomy = classify_kaiju(
self.ctx, fake_seqs, fake_db, num_partitions=2
)

self.ctx.get_action.assert_any_call("moshpit", "_classify_kaiju")
self.ctx.get_action.assert_any_call("feature_table", "merge")
self.ctx.get_action.assert_any_call("feature_table", "merge_taxa")
self.ctx.get_action.assert_any_call(
"demux", "partition_samples_paired"
)

self.mock_partition_single.assert_not_called()
self.mock_partition_paired.assert_called_once_with(fake_seqs, 2)
self.mock_classify_kaiju.assert_has_calls([
call(
"part1", fake_db, z=1, a='greedy', e=3, m=11, s=65,
evalue=0.01, x=True, r='species', c=0.0, exp=False, u=False
),
call(
"part2", fake_db, z=1, a='greedy', e=3, m=11, s=65,
evalue=0.01, x=True, r='species', c=0.0, exp=False, u=False
)
])
self.mock_merge.assert_called_once_with(["table1", "table2"])
self.mock_merge_taxa.assert_called_once_with(
["taxonomy1", "taxonomy2"]
)
self.assertEqual("merged_table", out_table)
self.assertEqual("merged_taxa", out_taxonomy)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit 428fcae

Please sign in to comment.