Skip to content

Commit

Permalink
added knn
Browse files Browse the repository at this point in the history
  • Loading branch information
BenKaehler committed Jul 30, 2018
1 parent bb3d44f commit 5c27a78
Show file tree
Hide file tree
Showing 9 changed files with 288 additions and 35 deletions.
16 changes: 10 additions & 6 deletions q2_clawback/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,21 @@
# ----------------------------------------------------------------------------

from ._version import get_versions
from ._clawback import (summarize_QIITA_metadata_category_and_contexts,
fetch_QIITA_samples,
from ._clawback import (summarize_Qiita_metadata_category_and_contexts,
fetch_Qiita_samples,
sequence_variants_from_samples,
generate_class_weights,
assemble_weights_from_QIITA)
assemble_weights_from_Qiita)
from ._knn import precalculate_nearest_neighbors, kNN_LOOCV_F_measures
from ._format import (PrecalculatedNearestNeighborsFormat,
PrecalculatedNearestNeighborsDirectoryFormat)
from ._type import PrecalculatedNearestNeighbors

__all__ = ['summarize_QIITA_metadata_category_and_contexts',
__all__ = ['summarize_Qiita_metadata_category_and_contexts',
'sequence_variants_from_samples',
'fetch_QIITA_samples',
'fetch_Qiita_samples',
'generate_class_weights',
'assemble_weights_from_QIITA']
'assemble_weights_from_Qiita']

__version__ = get_versions()['version']
del get_versions
14 changes: 7 additions & 7 deletions q2_clawback/_clawback.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,31 +27,31 @@ def sequence_variants_from_samples(samples: biom.Table) -> DNAIterator:
return DNAIterator(seqs)


def _fetch_QIITA_summaries(category='sample_type'):
def _fetch_Qiita_summaries(category='sample_type'):
md = redbiom.fetch.category_sample_values(category)
counts = md.value_counts(ascending=False)
caches = redbiom.summarize.contexts()[['ContextName', 'SamplesWithData']]
caches = caches.sort_values(by='SamplesWithData', ascending=False)
return counts, caches


def summarize_QIITA_metadata_category_and_contexts(
def summarize_Qiita_metadata_category_and_contexts(
output_dir: str=None, category: str='sample_type'):
counts, caches = _fetch_QIITA_summaries(category=category)
counts, caches = _fetch_Qiita_summaries(category=category)
counts = counts.to_frame()
counts = DataFrame({category: counts.index, 'count': counts.values.T[0]},
columns=[category, 'count'])
sample_types = q2templates.df_to_html(counts, bold_rows=False, index=False)
contexts = q2templates.df_to_html(caches, index=False)
title = 'Available in QIITA'
title = 'Available in Qiita'
index = os.path.join(TEMPLATES, 'index.html')
q2templates.render(index, output_dir, context={
'title': title,
'sample_types': sample_types,
'contexts': contexts})


def fetch_QIITA_samples(metadata_value: list, context: str,
def fetch_Qiita_samples(metadata_value: list, context: str,
metadata_key: str='sample_type') -> biom.Table:
query = "where " + metadata_key + " == '"
query += ("' or " + metadata_key + " == '").join(metadata_value)
Expand Down Expand Up @@ -90,11 +90,11 @@ def generate_class_weights(
return biom.Table(weights[None].T, taxa, sample_ids=['Weight'])


def assemble_weights_from_QIITA(
def assemble_weights_from_Qiita(
ctx, classifier, reference_taxonomy, reference_sequences,
metadata_value, context, unobserved_weight=1e-6, normalise=False,
metadata_key='sample_type'):
samples, = ctx.get_action('clawback', 'fetch_QIITA_samples')(
samples, = ctx.get_action('clawback', 'fetch_Qiita_samples')(
metadata_value=metadata_value, context=context,
metadata_key=metadata_key)

Expand Down
40 changes: 40 additions & 0 deletions q2_clawback/_format.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2016-2017, Ben Kaehler.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------

import json

import qiime2.plugin.model as model
from qiime2.plugin import ValidationError


def _validate_precalculated_nearest_neighbors(flat_nn):
    """Check that ``flat_nn`` has the precalculated-nearest-neighbors layout.

    The accepted layout is a dict that contains the keys ``"neighbors"`` and
    ``"taxonomies"``, whose values are all lists, and in which those two
    lists have the same length.

    Raises
    ------
    ValidationError
        If any of the conditions above does not hold.
    """
    if not isinstance(flat_nn, dict):
        raise ValidationError("Expected JSON-encoded dict")

    required_keys = ("neighbors", "taxonomies")
    if any(key not in flat_nn for key in required_keys):
        raise ValidationError('Expected dict to have keys '
                              '"neighbors" and "taxonomies"')

    # Every value is checked, not only the two required ones.
    for value in flat_nn.values():
        if not isinstance(value, list):
            raise ValidationError("Expected dict of lists")

    neighbors, taxonomies = flat_nn["neighbors"], flat_nn["taxonomies"]
    if len(neighbors) != len(taxonomies):
        raise ValidationError('Expected neighbors and taxonomies '
                              'to have equal lengths')


class PrecalculatedNearestNeighborsFormat(model.TextFileFormat):
    """JSON text file holding precalculated nearest neighbors."""

    def validate(self, level):
        """Parse the file as JSON and check its structure.

        ``level`` is accepted but not consulted; the same checks run
        whatever its value.

        Raises
        ------
        ValidationError
            If the file is not valid JSON or its contents fail the
            structural checks.
        """
        with self.open() as handle:
            try:
                contents = json.load(handle)
            except json.JSONDecodeError as err:
                raise ValidationError(err)
            _validate_precalculated_nearest_neighbors(contents)


# Directory format wrapping a single ``nearest_neighbors.json`` file whose
# contents are described by PrecalculatedNearestNeighborsFormat.
PrecalculatedNearestNeighborsDirectoryFormat = model.SingleFileDirectoryFormat(
    'PrecalculatedNearestNeighborsDirectoryFormat',
    'nearest_neighbors.json',
    PrecalculatedNearestNeighborsFormat,
)
120 changes: 120 additions & 0 deletions q2_clawback/_knn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2016-2017, Ben Kaehler.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------

import os
import json
from collections import Counter

import pkg_resources
from imblearn.under_sampling import ClusterCentroids
from sklearn.base import TransformerMixin
from sklearn.neighbors.base import KNeighborsMixin
from pandas import Series, DataFrame
from sklearn.metrics import f1_score
import q2templates
from q2_types.feature_data import DNAIterator
from q2_feature_classifier.classifier import pipeline_from_spec
from q2_feature_classifier._skl import _extract_reads

# Filesystem path of the ``assets`` resource inside the ``q2_clawback``
# package; the visualizer below loads ``index.html`` from it.
TEMPLATES = pkg_resources.resource_filename('q2_clawback', 'assets')

# Default JSON pipeline specification for the feature extraction step:
# a single "ext" step configured as a HashingVectorizer over 7-character
# n-grams.
_default_feature_extractor = ''.join([
    '[["ext", {',
    ', '.join([
        '"analyzer": "char_wb"',
        '"__type__": "feature_extraction.text.HashingVectorizer"',
        '"n_features": 8192',
        '"strip_accents": null',
        '"ngram_range": [7, 7]',
        '"alternate_sign": false',
    ]),
    '}]]',
])

# Default JSON pipeline specification for the neighbor search step:
# a single "cls" step configured as NearestNeighbors with 11 neighbors.
_default_knn_classifier = (
    '[["cls", {"__type__": "neighbors.NearestNeighbors", '
    '"n_neighbors": 11}]]'
)


def precalculate_nearest_neighbors(
        reference_taxonomy: Series, reference_sequences: DNAIterator,
        max_centroids_per_class: int=10,
        feature_extractor_specification: str=_default_feature_extractor,
        knn_classifier_specification: str=_default_knn_classifier,
        n_jobs: int=1, random_state: int=42) -> dict:
    """Precompute nearest-neighbor indices for the reference sequences.

    Parameters
    ----------
    reference_taxonomy : Series
        Taxonomy labels looked up by sequence id. Sequences whose id is not
        found in it are dropped.
    reference_sequences : DNAIterator
        Reference sequences, handed to ``_extract_reads``.
    max_centroids_per_class : int
        Classes with more members than this are undersampled down to this
        many with ``ClusterCentroids``. A value <= 0 disables undersampling.
    feature_extractor_specification : str
        JSON pipeline specification whose last step must be a transformer.
    knn_classifier_specification : str
        JSON pipeline specification whose last step must be a KNeighbors
        estimator.
    n_jobs : int
        Forwarded to ``ClusterCentroids`` and, when the neighbors estimator
        exposes an ``n_jobs`` parameter, to that estimator as well.
    random_state : int
        Forwarded to ``ClusterCentroids``.

    Returns
    -------
    dict
        ``'neighbors'``: the ``kneighbors`` index array as nested lists;
        ``'taxonomies'``: the (possibly resampled) labels as a list.

    Raises
    ------
    ValueError
        If either specification ends in the wrong kind of estimator, or if
        no reference sequence has an entry in the reference taxonomy.
    """
    spec = json.loads(feature_extractor_specification)
    feat_ext = pipeline_from_spec(spec)
    if not isinstance(feat_ext.steps[-1][-1], TransformerMixin):
        raise ValueError('feature_extractor_specification must specify a '
                         'transformer')
    spec = json.loads(knn_classifier_specification)
    nn = pipeline_from_spec(spec)
    if not isinstance(nn.steps[-1][-1], KNeighborsMixin):
        raise ValueError('knn_classifier_specification must specify a '
                         'KNeighbors classifier')

    # assumes _extract_reads returns parallel (ids, reads) -- TODO confirm
    seq_ids, X = _extract_reads(reference_sequences)
    data = [(reference_taxonomy[s], x)
            for s, x in zip(seq_ids, X) if s in reference_taxonomy]
    if not data:
        # Without this guard the unpacking below fails with an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError('None of the reference sequences have an entry in '
                         'the reference taxonomy')
    y, X = list(zip(*data))
    X = feat_ext.transform(X)

    if max_centroids_per_class > 0:
        class_counts = Counter(y)
        # Only classes above the cap are listed, each mapped to the cap.
        undersample_classes = {t: max_centroids_per_class
                               for t, c in class_counts.items()
                               if c > max_centroids_per_class}
        cc = ClusterCentroids(random_state=random_state, n_jobs=n_jobs,
                              ratio=undersample_classes, voting='hard')
        X_resampled, y_resampled = cc.fit_sample(X, y)
    else:
        X_resampled, y_resampled = X, y

    if 'n_jobs' in nn.steps[-1][-1].get_params():
        nn.steps[-1][-1].set_params(n_jobs=n_jobs)
    nn.fit(X_resampled)
    nn = nn.steps[-1][-1]
    # NOTE(review): the features are densified only for n_jobs != 1; the
    # reason is not visible here -- confirm before changing.
    if n_jobs != 1 and hasattr(X_resampled, 'todense'):
        indices = nn.kneighbors(X_resampled.todense(), return_distance=False)
    else:
        indices = nn.kneighbors(X_resampled, return_distance=False)

    # Without undersampling, y_resampled is the plain tuple produced by
    # zip() above, which has no .tolist(); handle both cases.
    if hasattr(y_resampled, 'tolist'):
        taxonomies = y_resampled.tolist()
    else:
        taxonomies = list(y_resampled)
    return {'neighbors': indices.tolist(), 'taxonomies': taxonomies}


def _loocv(y, indices, weights, uniform_prior=False):
    """Score leave-one-out nearest-neighbor voting with a weighted F-measure.

    Parameters
    ----------
    y : sequence
        Label of each sample.
    indices : sequence of sequences of int
        One row of neighbor positions per sample. The first entry of each
        row is skipped when voting.
        NOTE(review): presumably the first entry is the sample itself --
        confirm against how the rows are produced.
    weights : dict
        Weight per label. Always used to weight the final score, and also
        used to weight the votes unless ``uniform_prior`` is set.
    uniform_prior : bool
        When true, votes are weighted as if every label were equally
        likely, i.e. ``1 / (number of labels) / (label frequency)``.

    Returns
    -------
    The value of ``f1_score(..., average='weighted')`` for the predictions,
    with per-sample weights ``weights[label] / (label frequency)``.

    Raises
    ------
    ValueError
        If ``weights`` does not cover the labels in ``y`` (for the
        non-uniform case the two label sets must be equal), or if a row
        of ``indices`` leaves no neighbor to vote.
    """
    mismatch = ('Nearest neighbors and weights were calculated '
                'using different reference data sets')
    yfreq = Counter(y)
    if uniform_prior:
        # The votes ignore `weights`, but the score below still looks up
        # every label in it; fail with the descriptive error instead of a
        # bare KeyError.
        if any(t not in weights for t in yfreq):
            raise ValueError(mismatch)
    elif yfreq.keys() != weights.keys():
        raise ValueError(mismatch)

    score_weights = [weights[t]/yfreq[t] for t in y]
    if uniform_prior:
        vote_weights = [1./len(yfreq)/yfreq[t] for t in y]
    else:
        vote_weights = score_weights

    pred = []
    for row in indices:
        vote = Counter()
        for ix in row[1:]:
            vote[y[ix]] += vote_weights[ix]
        if not vote:
            raise ValueError('Each sample needs at least one neighbor '
                             'besides its first entry to cast a vote')
        pred.append(vote.most_common(1)[0][0])
    return f1_score(y, pred, average='weighted', sample_weight=score_weights)


def kNN_LOOCV_F_measures(output_dir: str,
                         nearest_neighbors: dict, class_weight: DataFrame):
    """Render a table comparing weighted and uniform LOOCV F-measures.

    The F-measure is computed twice from the precalculated neighbors: once
    with uniform vote weights and once with the weights taken from the
    ``'Weight'`` column of the transposed ``class_weight`` table. Both
    values and their difference are rendered into ``output_dir`` using the
    ``index.html`` template.
    """
    taxonomies = nearest_neighbors['taxonomies']
    neighbors = nearest_neighbors['neighbors']
    weights = class_weight.T['Weight'].to_dict()

    uniform = _loocv(taxonomies, neighbors, weights, True)
    bespoke = _loocv(taxonomies, neighbors, weights)

    row_labels = ['Weighted', 'Uniform', 'Difference']
    scores = [bespoke, uniform, bespoke-uniform]
    table = DataFrame({'F-measure': scores}, index=row_labels)

    template = os.path.join(TEMPLATES, 'index.html')
    context = {
        'title': 'Indicators of Taxonomic Weight Importance',
        'f_measures': q2templates.df_to_html(table),
    }
    q2templates.render(template, output_dir, context=context)
28 changes: 28 additions & 0 deletions q2_clawback/_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2016-2017, Ben Kaehler.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------

import json

from .plugin_setup import plugin
from ._format import PrecalculatedNearestNeighborsFormat, \
_validate_precalculated_nearest_neighbors


@plugin.register_transformer
def _1(data: dict) -> PrecalculatedNearestNeighborsFormat:
    """Validate ``data`` and write it out as JSON in a new format file."""
    _validate_precalculated_nearest_neighbors(data)
    result = PrecalculatedNearestNeighborsFormat()
    with result.open() as handle:
        json.dump(data, handle)
    return result


@plugin.register_transformer
def _2(ff: PrecalculatedNearestNeighborsFormat) -> dict:
    """Return the parsed JSON contents of ``ff``."""
    with ff.open() as handle:
        contents = json.load(handle)
    return contents
11 changes: 11 additions & 0 deletions q2_clawback/_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# ----------------------------------------------------------------------------
# Copyright (c) 2016-2017, Ben Kaehler.
#
# Distributed under the terms of the Modified BSD License.
#
# The full license is in the file LICENSE, distributed with this software.
# ----------------------------------------------------------------------------

from qiime2.plugin import SemanticType

# Semantic type named 'PrecalculatedNearestNeighbors'.
PrecalculatedNearestNeighbors = SemanticType('PrecalculatedNearestNeighbors')
32 changes: 22 additions & 10 deletions q2_clawback/assets/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,28 @@
{% block title %}q2-clawback : {{ title }}{% endblock %}

{% block content %}
<h1>{{ title }}</h1>

<div class="row">
<h1>Metadata Values</h1>
<div class="col-lg-12">
{{ sample_types }}
</div>
<h1>Contexts</h1>
<div class="col-lg-12">
{{ contexts }}
</div>
</div>
{% if sample_types %}
<div class="row">
<h1>Metadata Values</h1>
<div class="col-lg-12">
{{ sample_types }}
</div>
<h1>Contexts</h1>
<div class="col-lg-12">
{{ contexts }}
</div>
</div>
{% endif %}

{% if f_measures %}
<div class="row">
<div class="col-lg-12">
{{ f_measures }}
</div>
</div>

{% endif %}

{% endblock %}
Loading

0 comments on commit 5c27a78

Please sign in to comment.