From bf9cf08badb6a0851aa483c4b4fa48e961b8a6d4 Mon Sep 17 00:00:00 2001 From: Liz Gehret Date: Thu, 3 Oct 2024 10:19:17 -0700 Subject: [PATCH] fix test failure --- rescript/cross_validate.py | 25 +++++++++++++++++++------ rescript/tests/test_cross_validate.py | 16 ++++++---------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/rescript/cross_validate.py b/rescript/cross_validate.py index 5f49357..7b36b3e 100644 --- a/rescript/cross_validate.py +++ b/rescript/cross_validate.py @@ -162,11 +162,9 @@ def _split_fasta(sequences, train_ids, test_ids): return train_seqs, test_seqs -def evaluate_classifications(ctx, - expected_taxonomies, - observed_taxonomies, - labels=None): - lineplot = ctx.get_action('vizard', 'lineplot') +def _evaluate_classifications_stats(expected_taxonomies, + observed_taxonomies, + labels=None): # Validate inputs. if len(expected_taxonomies) != len(observed_taxonomies): raise ValueError('Expected and Observed Taxonomies do not match. ' @@ -201,7 +199,22 @@ def evaluate_classifications(ctx, # convert index to strings precision_recall.index = pd.Index( [str(i) for i in range(1, len(precision_recall.index) + 1)], name='id') - plots, = lineplot(metadata=q2.Metadata(precision_recall), + + return q2.Metadata(precision_recall) + + +def evaluate_classifications(ctx, + expected_taxonomies, + observed_taxonomies, + labels=None): + lineplot = ctx.get_action('vizard', 'lineplot') + + md = _evaluate_classifications_stats( + expected_taxonomies=expected_taxonomies, + observed_taxonomies=observed_taxonomies, + labels=labels) + + plots, = lineplot(metadata=md, x_measure='Level', y_measure='F-Measure', group_by='Dataset', diff --git a/rescript/tests/test_cross_validate.py b/rescript/tests/test_cross_validate.py index 1c85575..2e92d07 100644 --- a/rescript/tests/test_cross_validate.py +++ b/rescript/tests/test_cross_validate.py @@ -6,8 +6,6 @@ # The full license is in the file LICENSE, distributed with this software. # ---------------------------------------------------------------------------- -import os - from qiime2.plugin.testing import TestPluginBase from qiime2.plugins import rescript import qiime2 @@ -15,6 +13,7 @@ import pandas.testing as pdt from rescript import cross_validate +from ..cross_validate import _evaluate_classifications_stats import_data = qiime2.Artifact.import_data @@ -92,22 +91,19 @@ def test_evaluate_fit_classifier(self): pdt.assert_series_equal( obs.view(pd.Series).sort_index(), exp_obs, check_names=False) - def test_evaluate_classifications(self): + def test_evaluate_classifications_stats(self): # simulate predicted classifications at genus level taxa = self.taxa_series.copy().apply( lambda x: ';'.join(x.split(';')[:6])) taxa = qiime2.Artifact.import_data('FeatureData[Taxonomy]', taxa) # first round we just make sure this runs - rescript.actions.evaluate_classifications([self.taxa], [taxa]) + _evaluate_classifications_stats([self.taxa], [taxa]) # now the same but input multiple times to test lists of inputs - vol, = rescript.actions.evaluate_classifications( - [self.taxa, taxa], [taxa, taxa]) + obs = _evaluate_classifications_stats([self.taxa, taxa], [taxa, taxa]) # now inspect and validate the contents # we inspect the second eval results to compare perfect match vs. # simulated genus-level classification (when species are expected) - vol.export_data(self.temp_dir.name) - html_path = os.path.join(self.temp_dir.name, 'data.tsv') - vol = qiime2.Metadata.load(html_path).to_dataframe() + obs_df = obs.to_dataframe() exp = pd.DataFrame({ 'Level': { '1': 1.0, '2': 2.0, '3': 3.0, '4': 4.0, '5': 5.0, '6': 6.0, @@ -128,7 +124,7 @@ def test_evaluate_classifications(self): 'Dataset': {'1': '1', '2': '1', '3': '1', '4': '1', '5': '1', '6': '1', '7': '1', '8': '2', '9': '2', '10': '2', '11': '2', '12': '2', '13': '2'}}).sort_index() - pdt.assert_frame_equal(vol.sort_index(), exp, check_names=False) + pdt.assert_frame_equal(obs_df.sort_index(), exp, check_names=False) def test_evaluate_classifications_mismatch_input_count(self): with self.assertRaisesRegex(