diff --git a/q2_amr/card/normalization.py b/q2_amr/card/normalization.py index 9f955d0..d6dab5a 100644 --- a/q2_amr/card/normalization.py +++ b/q2_amr/card/normalization.py @@ -1,75 +1,31 @@ import os -import biom import pandas as pd from q2_types.feature_data import SequenceCharacteristicsDirectoryFormat from rnanorm import CPM, CTF, CUF, FPKM, TMM, TPM, UQ def normalize( - table: biom.Table, + table: pd.DataFrame, method: str, - m_trim: float = 0.3, - a_trim: float = 0.05, + m_trim: float = None, + a_trim: float = None, gene_length: SequenceCharacteristicsDirectoryFormat = None, ) -> pd.DataFrame: - # Create Dataframe with counts from biom.Table - counts = pd.DataFrame( - data=table.matrix_data.toarray(), - index=table.ids(axis="observation"), - columns=table.ids(axis="sample"), - ).T - - if method in ["tpm", "fpkm", "uq", "cuf", "cpm"]: - # Raise Error if m or a-trim parameters are given with methods TPM, FPKM, UQ, - # CPM or CUF - if m_trim != 0.3 or a_trim != 0.05: - raise ValueError( - "Parameters m-trim and a-trim can only be used with methods TMM and " - "CTF." - ) - - if method in ["tpm", "fpkm"]: - # Raise Error if gene-length is missing when using methods TPM or FPKM - if not gene_length: - raise ValueError("gene-length input is missing.") - - # Create pd.Series from gene_length input - lengths = pd.read_csv( - os.path.join(gene_length.path, "sequence_characteristics.tsv"), - sep="\t", - header=None, - names=["index", "values"], - index_col="index", - squeeze=True, - skiprows=1, - ) - - # Raise Error if there are genes in the counts that are not present in the - # gene length - if not set(counts.columns).issubset(set(lengths.index)): - only_in_counts = set(counts.columns) - set(lengths.index) - raise ValueError( - f"There are genes present in the FeatureTable that are not present " - f"in the gene-length input. Missing lengths for genes: " - f"{only_in_counts}" - ) - - # Define the methods TPM and FPKM with the gene length series as an input - methods = { - "tpm": TPM(gene_lengths=lengths), - "fpkm": FPKM(gene_lengths=lengths), - } - - if method in ["tmm", "uq", "cuf", "ctf", "cpm"]: - # Raise Error if gene-length is given when using methods TMM, UQ, CUF, CPM or - # CTF - if gene_length: - raise ValueError( - "gene-length input can only be used with FPKM and TPM methods." - ) - - # Define the methods TMM and CTF with parameters, also UQ, CPM and CUF + # Validate parameter combinations and set trim parameters + m_trim, a_trim = _validate_parameters(method, m_trim, a_trim, gene_length) + + # Process gene_lengths input and define methods that need gene_lengths input + if method in ["tpm", "fpkm"]: + lengths = _convert_lengths(table, gene_length) + + methods = { + "tpm": TPM(gene_lengths=lengths), + "fpkm": FPKM(gene_lengths=lengths), + } + + # Define remaining methods that don't need gene_lengths input + else: methods = { "tmm": TMM(m_trim=m_trim, a_trim=a_trim), "ctf": CTF(m_trim=m_trim, a_trim=a_trim), @@ -78,8 +34,57 @@ def normalize( "cpm": CPM(), } - # Run normalization method on count dataframe - normalized = methods[method].set_output(transform="pandas").fit_transform(counts) + # Run normalization method on frequency table + normalized = methods[method].set_output(transform="pandas").fit_transform(table) normalized.index.name = "sample_id" return normalized + + +def _validate_parameters(method, m_trim, a_trim, gene_length): + # Raise Error if gene-length is missing when using methods TPM or FPKM + if method in ["tpm", "fpkm"] and not gene_length: + raise ValueError("gene-length input is missing.") + + # Raise Error if gene-length is given when using methods TMM, UQ, CUF, CPM or CTF + if method in ["tmm", "uq", "cuf", "ctf", "cpm"] and gene_length: + raise ValueError( + "gene-length input can only be used with FPKM and TPM " "methods." + ) + + # Raise Error if m_trim or a_trim are given when not using methods TMM or CTF + if (method not in ["tmm", "ctf"]) and (m_trim is not None or a_trim is not None): + raise ValueError( + "Parameters m-trim and a-trim can only be used with methods TMM and " "CTF." + ) + + # Set m_trim and a_trim to their default values for methods TMM and CTF + if method in ["tmm", "ctf"]: + m_trim = 0.3 if m_trim is None else m_trim + a_trim = 0.05 if a_trim is None else a_trim + + return m_trim, a_trim + + +def _convert_lengths(table, gene_length): + # Read in table from sequence_characteristics.tsv as a pd.Series + lengths = pd.read_csv( + os.path.join(gene_length.path, "sequence_characteristics.tsv"), + sep="\t", + header=None, + names=["index", "values"], + index_col="index", + squeeze=True, + skiprows=1, + ) + + # Check if all gene IDs that are present in the table are also present in + # the lengths + if not set(table.columns).issubset(set(lengths.index)): + only_in_counts = set(table.columns) - set(lengths.index) + raise ValueError( + f"There are genes present in the FeatureTable that are not present " + f"in the gene-length input. Missing lengths for genes: " + f"{only_in_counts}" + ) + return lengths diff --git a/q2_amr/card/tests/data/feature-table.biom b/q2_amr/card/tests/data/feature-table.biom deleted file mode 100644 index 3a8fcee..0000000 Binary files a/q2_amr/card/tests/data/feature-table.biom and /dev/null differ diff --git a/q2_amr/card/tests/data/feature-table.tsv b/q2_amr/card/tests/data/feature-table.tsv new file mode 100644 index 0000000..712090b --- /dev/null +++ b/q2_amr/card/tests/data/feature-table.tsv @@ -0,0 +1,3 @@ +ID ARO:3000026|ID:377|Name:mepA|NCBI:AY661734.1 ARO:3000027|ID:1757|Name:emrA|NCBI:AP009048.1 +sample1 2.0 0.0 +sample2 2.0 0.0 diff --git a/q2_amr/card/tests/test_normalization.py b/q2_amr/card/tests/test_normalization.py index a5e1689..18eb790 100644 --- a/q2_amr/card/tests/test_normalization.py +++ b/q2_amr/card/tests/test_normalization.py @@ -2,11 +2,12 @@ import shutil from unittest.mock import MagicMock, patch -import biom +import pandas as pd +from pandas._testing import assert_series_equal from q2_types.feature_data import SequenceCharacteristicsDirectoryFormat from qiime2.plugin.testing import TestPluginBase -from q2_amr.card.normalization import normalize +from q2_amr.card.normalization import _convert_lengths, _validate_parameters, normalize class TestNormalize(TestPluginBase): @@ -14,51 +15,79 @@ class TestNormalize(TestPluginBase): @classmethod def setUpClass(cls): - # Mocking the biom.Table and gene_length class - cls.table = MagicMock() cls.gene_length = MagicMock() + cls.lengths = pd.Series( + { + "ARO:3000026|ID:377|Name:mepA|NCBI:AY661734.1": 1356.0, + "ARO:3000027|ID:1757|Name:emrA|NCBI:AP009048.1": 1173.0, + }, + name="values", + ) + cls.lengths.index.name = "index" - def test_tpm_fpkm_uq_cuf_with_invalid_m_a_trim(self): + def test_validate_parameters_uq_with_m_a_trim(self): # Test Error raised if gene-length is given with TMM method - expected_message = ( - "Parameters m-trim and a-trim can only be used with methods TMM and CTF." - ) - with self.assertRaises(ValueError) as cm: - normalize(self.table, "tpm", m_trim=0.2, a_trim=0.05) - self.assertEqual(str(cm.exception), expected_message) + with self.assertRaisesRegex( + ValueError, + "Parameters m-trim and a-trim can only " + "be used with methods TMM and CTF.", + ): + _validate_parameters("uq", 0.2, 0.05, None) - def test_tpm_fpkm_with_missing_gene_length(self): + def test_validate_parameters_tpm_missing_gene_length(self): # Test Error raised if gene-length is missing with TPM method - expected_message = "gene-length input is missing." - with self.assertRaises(ValueError) as cm: - normalize(self.table, "tpm") - self.assertEqual(str(cm.exception), expected_message) + with self.assertRaisesRegex(ValueError, "gene-length input is missing."): + _validate_parameters("tpm", None, None, None) - def test_tmm_uq_cuf_ctf_with_gene_length(self): + def test_validate_parameters_tmm_gene_length(self): # Test Error raised if gene-length is given with TMM method - expected_message = ( - "gene-length input can only be used with FPKM and TPM methods." + with self.assertRaisesRegex( + ValueError, "gene-length input can only be used with FPKM and TPM methods." + ): + _validate_parameters("tmm", None, None, gene_length=self.gene_length) + + def test_validate_parameters_default_m_a_trim(self): + # Test if m_trim and a_trim get set to default values if None + m_trim, a_trim = _validate_parameters("tmm", None, None, None) + self.assertEqual(m_trim, 0.3) + self.assertEqual(a_trim, 0.05) + + def test_validate_parameters_m_a_trim(self): + # Test if m_trim and a_trim are not modified if not None + m_trim, a_trim = _validate_parameters("tmm", 0.1, 0.06, None) + self.assertEqual(m_trim, 0.1) + self.assertEqual(a_trim, 0.06) + + def test_convert_lengths_gene_length(self): + # Test Error raised if gene-length is missing genes + gene_length = SequenceCharacteristicsDirectoryFormat() + shutil.copy( + self.get_data_path("sequence_characteristics.tsv"), gene_length.path + ) + table = pd.read_csv( + self.get_data_path("feature-table.tsv"), sep="\t", index_col="ID" ) - with self.assertRaises(ValueError) as cm: - normalize(self.table, "tmm", gene_length=self.gene_length) - self.assertEqual(str(cm.exception), expected_message) - def test_tpm_fpkm_with_short_gene_length(self): + obs = _convert_lengths(table, gene_length=gene_length) + assert_series_equal(obs, self.lengths) + + def test_convert_lengths_short_gene_length(self): # Test Error raised if gene-length is missing genes gene_length = SequenceCharacteristicsDirectoryFormat() shutil.copy( self.get_data_path("sequence_characteristics_short.tsv"), os.path.join(gene_length.path, "sequence_characteristics.tsv"), ) - table = biom.load_table(self.get_data_path("feature-table.biom")) - expected_message = ( + table = pd.read_csv( + self.get_data_path("feature-table.tsv"), sep="\t", index_col="ID" + ) + with self.assertRaisesRegex( + ValueError, "There are genes present in the FeatureTable that are not present " "in the gene-length input. Missing lengths for genes: " - "{'ARO:3000027|ID:1757|Name:emrA|NCBI:AP009048.1'}" - ) - with self.assertRaises(ValueError) as cm: - normalize(table, "tpm", gene_length=gene_length) - self.assertEqual(str(cm.exception), expected_message) + "{'ARO:3000027|ID:1757|Name:emrA|NCBI:AP009048.1'}", + ): + _convert_lengths(table, gene_length=gene_length) @patch("q2_amr.card.normalization.TPM") def test_tpm_fpkm_with_valid_inputs(self, mock_tpm): @@ -67,11 +96,15 @@ def test_tpm_fpkm_with_valid_inputs(self, mock_tpm): shutil.copy( self.get_data_path("sequence_characteristics.tsv"), gene_length.path ) - table = biom.load_table(self.get_data_path("feature-table.biom")) + table = pd.read_csv( + self.get_data_path("feature-table.tsv"), sep="\t", index_col="ID" + ) normalize(table=table, gene_length=gene_length, method="tpm") @patch("q2_amr.card.normalization.TMM") def test_tmm_uq_cuf_ctf_with_valid_inputs(self, mock_tmm): # Test valid inputs for TMM method - table = biom.load_table(self.get_data_path("feature-table.biom")) + table = pd.read_csv( + self.get_data_path("feature-table.tsv"), sep="\t", index_col="ID" + ) normalize(table=table, method="tmm", a_trim=0.06, m_trim=0.4) diff --git a/q2_amr/plugin_setup.py b/q2_amr/plugin_setup.py index 456776d..93dcd23 100644 --- a/q2_amr/plugin_setup.py +++ b/q2_amr/plugin_setup.py @@ -470,8 +470,8 @@ }, parameters={ "method": Str % Choices(["tpm", "fpkm", "tmm", "uq", "cuf", "ctf", "cpm"]), - "m_trim": Float % Range(0, 1, inclusive_start=True), - "a_trim": Float % Range(0, 1, inclusive_start=True), + "m_trim": Float % Range(0, 1, inclusive_start=True, inclusive_end=True), + "a_trim": Float % Range(0, 1, inclusive_start=True, inclusive_end=True), }, outputs=[("normalized_table", FeatureTable[Frequency])], input_descriptions={ @@ -485,9 +485,9 @@ "/12/2023-Normalizing-RNA-seq-data-in-Python-with-RNAnorm.pdf for " "more information on the methods.", "m_trim": "Two sided cutoff for M-values. Can only be used for methods TMM and " - "CTF.", + "CTF. (default = 0.3)", "a_trim": "Two sided cutoff for A-values. Can only be used for methods TMM and " - "CTF.", + "CTF. (default = 0.05)", }, output_descriptions={ "normalized_table": "Feature table normalized with specified " "method."