Skip to content

Commit

Permalink
fix test failure
Browse files: browse the repository at this point in the history
  • Loading branch information
lizgehret committed Oct 3, 2024
1 parent 967a238 commit bf9cf08
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 16 deletions.
25 changes: 19 additions & 6 deletions rescript/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,9 @@ def _split_fasta(sequences, train_ids, test_ids):
return train_seqs, test_seqs


def evaluate_classifications(ctx,
expected_taxonomies,
observed_taxonomies,
labels=None):
lineplot = ctx.get_action('vizard', 'lineplot')
def _evaluate_classifications_stats(expected_taxonomies,
observed_taxonomies,
labels=None):
# Validate inputs.
if len(expected_taxonomies) != len(observed_taxonomies):
raise ValueError('Expected and Observed Taxonomies do not match. '
Expand Down Expand Up @@ -201,7 +199,22 @@ def evaluate_classifications(ctx,
# convert index to strings
precision_recall.index = pd.Index(
[str(i) for i in range(1, len(precision_recall.index) + 1)], name='id')
plots, = lineplot(metadata=q2.Metadata(precision_recall),

return q2.Metadata(precision_recall)


def evaluate_classifications(ctx,
expected_taxonomies,
observed_taxonomies,
labels=None):
lineplot = ctx.get_action('vizard', 'lineplot')

md = _evaluate_classifications_stats(
expected_taxonomies=expected_taxonomies,
observed_taxonomies=observed_taxonomies,
labels=labels)

plots, = lineplot(metadata=md,
x_measure='Level',
y_measure='F-Measure',
group_by='Dataset',
Expand Down
16 changes: 6 additions & 10 deletions rescript/tests/test_cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------

import os

from qiime2.plugin.testing import TestPluginBase
from qiime2.plugins import rescript
import qiime2
import pandas as pd
import pandas.testing as pdt

from rescript import cross_validate
from ..cross_validate import _evaluate_classifications_stats


import_data = qiime2.Artifact.import_data
Expand Down Expand Up @@ -92,22 +91,19 @@ def test_evaluate_fit_classifier(self):
pdt.assert_series_equal(
obs.view(pd.Series).sort_index(), exp_obs, check_names=False)

def test_evaluate_classifications(self):
def test_evaluate_classifications_stats(self):
# simulate predicted classifications at genus level
taxa = self.taxa_series.copy().apply(
lambda x: ';'.join(x.split(';')[:6]))
taxa = qiime2.Artifact.import_data('FeatureData[Taxonomy]', taxa)
# first round we just make sure this runs
rescript.actions.evaluate_classifications([self.taxa], [taxa])
_evaluate_classifications_stats([self.taxa], [taxa])
# now the same but input multiple times to test lists of inputs
vol, = rescript.actions.evaluate_classifications(
[self.taxa, taxa], [taxa, taxa])
obs = _evaluate_classifications_stats([self.taxa, taxa], [taxa, taxa])
# now inspect and validate the contents
# we inspect the second eval results to compare perfect match vs.
# simulated genus-level classification (when species are expected)
vol.export_data(self.temp_dir.name)
html_path = os.path.join(self.temp_dir.name, 'data.tsv')
vol = qiime2.Metadata.load(html_path).to_dataframe()
obs_df = obs.to_dataframe()
exp = pd.DataFrame({
'Level': {
'1': 1.0, '2': 2.0, '3': 3.0, '4': 4.0, '5': 5.0, '6': 6.0,
Expand All @@ -128,7 +124,7 @@ def test_evaluate_classifications(self):
'Dataset': {'1': '1', '2': '1', '3': '1', '4': '1', '5': '1',
'6': '1', '7': '1', '8': '2', '9': '2', '10': '2',
'11': '2', '12': '2', '13': '2'}}).sort_index()
pdt.assert_frame_equal(vol.sort_index(), exp, check_names=False)
pdt.assert_frame_equal(obs_df.sort_index(), exp, check_names=False)

def test_evaluate_classifications_mismatch_input_count(self):
with self.assertRaisesRegex(
Expand Down

0 comments on commit bf9cf08

Please sign in to comment.