Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IMP: replace volatility with lineplot for evaluate_* visualizers #204

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ test:
requires:
- pytest
- q2-types >={{ q2_types }}
- q2-longitudinal >={{ q2_longitudinal }}
- q2-vizard >={{ q2_vizard }}
- q2-feature-classifier >={{ q2_feature_classifier }}
- qiime2 >={{ qiime2 }}

Expand Down
32 changes: 23 additions & 9 deletions rescript/cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,9 @@ def _split_fasta(sequences, train_ids, test_ids):
return train_seqs, test_seqs


def evaluate_classifications(ctx,
expected_taxonomies,
observed_taxonomies,
labels=None):
volatility = ctx.get_action('longitudinal', 'volatility')
def _evaluate_classifications_stats(expected_taxonomies,
observed_taxonomies,
labels=None):
# Validate inputs.
if len(expected_taxonomies) != len(observed_taxonomies):
raise ValueError('Expected and Observed Taxonomies do not match. '
Expand Down Expand Up @@ -201,10 +199,26 @@ def evaluate_classifications(ctx,
# convert index to strings
precision_recall.index = pd.Index(
[str(i) for i in range(1, len(precision_recall.index) + 1)], name='id')
plots, = volatility(metadata=q2.Metadata(precision_recall),
state_column='Level',
default_group_column='Dataset',
default_metric='F-Measure')

return q2.Metadata(precision_recall)


def evaluate_classifications(ctx,
                             expected_taxonomies,
                             observed_taxonomies,
                             labels=None):
    """Pipeline: visualize per-level classification accuracy.

    Delegates the precision/recall/F-measure computation to
    ``_evaluate_classifications_stats`` and renders the resulting
    Metadata with the q2-vizard ``lineplot`` action (Level on the
    x-axis, F-Measure on the y-axis, one line per Dataset).

    Returns the lineplot visualization result.
    """
    # Compute the per-level accuracy stats as QIIME 2 Metadata.
    stats_md = _evaluate_classifications_stats(
        expected_taxonomies=expected_taxonomies,
        observed_taxonomies=observed_taxonomies,
        labels=labels)

    # Render the stats as an interactive line plot.
    lineplot = ctx.get_action('vizard', 'lineplot')
    results = lineplot(metadata=stats_md,
                       x_measure='Level',
                       y_measure='F-Measure',
                       group_by='Dataset',
                       title='RESCRIPt Evaluate Classifications')
    return results[0]


Expand Down
12 changes: 7 additions & 5 deletions rescript/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,13 @@ def evaluate_taxonomy(ctx,
results.index = pd.Index(
[str(i) for i in range(1, len(results.index) + 1)], name='id')
results = q2.Metadata(results)
volatility = ctx.get_action('longitudinal', 'volatility')
plots, = volatility(metadata=results,
state_column='Level',
default_group_column='Dataset',
default_metric='Taxonomic Entropy')

lineplot = ctx.get_action('vizard', 'lineplot')
plots, = lineplot(metadata=results,
x_measure='Level',
y_measure='Taxonomic Entropy',
group_by='Dataset',
title='RESCRIPt Evaluate Taxonomy')
return plots


Expand Down
6 changes: 3 additions & 3 deletions rescript/plugin_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
'CC BY-SA 4.0. To learn more, please visit https://unite.ut.ee/cite.php '
'and https://creativecommons.org/licenses/by-sa/4.0/.')

VOLATILITY_PLOT_XAXIS_INTERPRETATION = (
LINEPLOT_XAXIS_INTERPRETATION = (
'The x-axis in these plots represents the taxonomic '
'levels present in the input taxonomies so are labeled numerically '
'instead of by rank, but typically for 7-level taxonomies these will '
Expand Down Expand Up @@ -225,7 +225,7 @@
'sets of true taxonomic labels to the predicted taxonomies for the '
'same set(s) of features. Output an interactive line plot of '
'classification accuracy for each pair of expected/observed '
'taxonomies. ' + VOLATILITY_PLOT_XAXIS_INTERPRETATION),
'taxonomies. ' + LINEPLOT_XAXIS_INTERPRETATION),
citations=[citations['bokulich2018optimizing'],
citations['bokulich2017q2']]
)
Expand Down Expand Up @@ -395,7 +395,7 @@
'unique labels, taxonomic entropy, and the number of features that '
'are (un)classified at each taxonomic level. This action is useful '
'for both reference taxonomies and classification results. ' +
VOLATILITY_PLOT_XAXIS_INTERPRETATION),
LINEPLOT_XAXIS_INTERPRETATION),
citations=[citations['bokulich2017q2']]
)

Expand Down
16 changes: 6 additions & 10 deletions rescript/tests/test_cross_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,14 @@
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------

import os

from qiime2.plugin.testing import TestPluginBase
from qiime2.plugins import rescript
import qiime2
import pandas as pd
import pandas.testing as pdt

from rescript import cross_validate
from ..cross_validate import _evaluate_classifications_stats


import_data = qiime2.Artifact.import_data
Expand Down Expand Up @@ -92,22 +91,19 @@ def test_evaluate_fit_classifier(self):
pdt.assert_series_equal(
obs.view(pd.Series).sort_index(), exp_obs, check_names=False)

def test_evaluate_classifications(self):
def test_evaluate_classifications_stats(self):
# simulate predicted classifications at genus level
taxa = self.taxa_series.copy().apply(
lambda x: ';'.join(x.split(';')[:6]))
taxa = qiime2.Artifact.import_data('FeatureData[Taxonomy]', taxa)
# first round we just make sure this runs
rescript.actions.evaluate_classifications([self.taxa], [taxa])
_evaluate_classifications_stats([self.taxa], [taxa])
# now the same but input multiple times to test lists of inputs
vol, = rescript.actions.evaluate_classifications(
[self.taxa, taxa], [taxa, taxa])
obs = _evaluate_classifications_stats([self.taxa, taxa], [taxa, taxa])
# now inspect and validate the contents
# we inspect the second eval results to compare perfect match vs.
# simulated genus-level classification (when species are expected)
vol.export_data(self.temp_dir.name)
html_path = os.path.join(self.temp_dir.name, 'data.tsv')
vol = qiime2.Metadata.load(html_path).to_dataframe()
obs_df = obs.to_dataframe()
exp = pd.DataFrame({
'Level': {
'1': 1.0, '2': 2.0, '3': 3.0, '4': 4.0, '5': 5.0, '6': 6.0,
Expand All @@ -128,7 +124,7 @@ def test_evaluate_classifications(self):
'Dataset': {'1': '1', '2': '1', '3': '1', '4': '1', '5': '1',
'6': '1', '7': '1', '8': '2', '9': '2', '10': '2',
'11': '2', '12': '2', '13': '2'}}).sort_index()
pdt.assert_frame_equal(vol.sort_index(), exp, check_names=False)
pdt.assert_frame_equal(obs_df.sort_index(), exp, check_names=False)

def test_evaluate_classifications_mismatch_input_count(self):
with self.assertRaisesRegex(
Expand Down
Loading