Skip to content

Commit

Permalink
Simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
misialq committed Dec 13, 2024
1 parent 3892a7f commit a18c138
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 19 deletions.
24 changes: 10 additions & 14 deletions q2_moshpit/dereplication/derep.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def _generate_pa_table(
return presence_absence


def _get_representatives(mags, metadata, column_name, bin_clusters, find_max):
def _get_representatives(mags, metadata, column, bin_clusters, find_max):
"""
This function identifies the representative bin for each cluster of bins.
If metadata is provided, the selection is based on a numerical metadata
Expand All @@ -268,7 +268,7 @@ def _get_representatives(mags, metadata, column_name, bin_clusters, find_max):
Args:
mags: A MultiMAGSequencesDirFmt object containing all bins.
metadata: Qiime metadata.
column_name: Name of a column in metadata.
column: Name of a column in metadata.
bin_clusters: A list of lists where each inner list contains the IDs
of bins grouped by similarity.
find_max: Boolean, if true the bin with the highest value in the
Expand All @@ -279,15 +279,14 @@ def _get_representatives(mags, metadata, column_name, bin_clusters, find_max):
A list of representative bin IDs, one for each cluster.
"""
bin_lengths = _get_bin_lengths(mags)
method = pd.Series.max if find_max else pd.Series.min

# Choose by metadata
if metadata is not None:
try:
metadata_column = (
metadata.to_dataframe()[column_name].astype(float)
)
metadata_col = metadata.to_dataframe()[column].astype(float)
except KeyError:
raise KeyError(f'The column "{column_name}" does not exist '
raise KeyError(f'The column "{column}" does not exist '
f'in the metadata.')
except ValueError:
raise ValueError('The specified metadata column has to be '
Expand All @@ -296,17 +295,14 @@ def _get_representatives(mags, metadata, column_name, bin_clusters, find_max):
representative_bins = []
for bins in bin_clusters:
# Get bin IDs with the max or min column values in cluster
max_min_value_bins = (
values := metadata_column[bins]
)[values == (values.max() if find_max else values.min())].index
values = metadata_col[bins]
selected_bins = values[values == method(values)].index

# If there's a tie, resolve by selecting the longest bin
if len(max_min_value_bins) > 1:
representative_bins.append(
bin_lengths[max_min_value_bins].idxmax()
)
if len(selected_bins) > 1:
representative_bins.append(bin_lengths[selected_bins].idxmax())
else:
representative_bins.append(max_min_value_bins[0])
representative_bins.append(selected_bins[0])

# Choose by length
else:
Expand Down
10 changes: 5 additions & 5 deletions q2_moshpit/dereplication/tests/test_dereplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_get_representatives_metadata_max_value(self):
obs = _get_representatives(
mags=self.bins,
metadata=qiime2.Metadata(self.busco_results),
column_name="complete",
column="complete",
bin_clusters=self.clusters_99,
find_max=True
)
Expand All @@ -185,7 +185,7 @@ def test_get_representatives_metadata_min_value(self):
obs = _get_representatives(
mags=self.bins,
metadata=qiime2.Metadata(self.busco_results),
column_name="complete",
column="complete",
bin_clusters=self.clusters_99,
find_max=False
)
Expand All @@ -200,7 +200,7 @@ def test_get_representatives_length(self):
obs = _get_representatives(
mags=self.bins,
metadata=None,
column_name=None,
column=None,
bin_clusters=self.clusters_99,
find_max=True
)
Expand All @@ -216,7 +216,7 @@ def test_get_representatives_key_error(self):
_get_representatives(
mags=self.bins,
metadata=qiime2.Metadata(self.busco_results),
column_name="version",
column="version",
bin_clusters=self.clusters_99,
find_max=True
)
Expand All @@ -226,7 +226,7 @@ def test_get_representatives_value_error(self):
_get_representatives(
mags=self.bins,
metadata=qiime2.Metadata(self.busco_results),
column_name="dataset",
column="dataset",
bin_clusters=self.clusters_99,
find_max=True
)
Expand Down

0 comments on commit a18c138

Please sign in to comment.