Skip to content

Commit

Permalink
changed partition to work with new sample_dict function changed tests…
Browse files Browse the repository at this point in the history
… for more readability
  • Loading branch information
VinzentRisch committed Apr 11, 2024
1 parent e1ae4ba commit e78db60
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 73 deletions.
60 changes: 20 additions & 40 deletions q2_amr/card/partition.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import itertools
import os
import warnings
from typing import Union

import numpy as np
from qiime2.util import duplicate

from q2_amr.types import (
Expand Down Expand Up @@ -130,23 +130,10 @@ def _partition_annotations(
num_partitions: int = None,
):
partitioned_annotations = {}
annotations_all = []

# Add one tuple with sample id MAG id and full paths to annotation files to
# annotations_all per annotation file
if isinstance(annotations, CARDAnnotationDirectoryFormat):
for sample_id, mag in annotations.sample_dict().items():
for mag_id, annotation_fp_list in mag.items():
for annotation_fp in annotation_fp_list:
annotations_all.append((sample_id, mag_id, annotation_fp))

else:
for sample_id, annotation_fp_list in annotations.sample_dict().items():
for annotation_fp in annotation_fp_list:
annotations_all.append((sample_id, annotation_fp))

# Retrieve the number of annotations
num_annotations = len({tup[-2] for tup in annotations_all})
# Get dict with paths to all files in artifact and get number of annotations
annotations_dict = annotations.sample_dict()
num_annotations = len(annotations_dict)

# If no number of partitions is specified or the number is higher than the number
# of annotations, all annotations get partitioned by annotation
Expand All @@ -162,39 +149,32 @@ def _partition_annotations(
)
num_partitions = num_annotations

# Splits annotations into the specified number of partitions
arrays = np.array_split(annotations_all, num_partitions)
# Split the dict into a list of specified number of dicts
i = itertools.cycle(range(num_partitions))
annotations_split_list = [{} for _ in range(num_partitions)]
for key, value in annotations_dict.items():
annotations_split_list[next(i)][key] = value

for i, annotation_tuple in enumerate(arrays, 1):
for i, annotations_split in enumerate(annotations_split_list, 1):
# Creates directory with same format as input
partitioned_annotation = type(annotations)()

# Constructs paths to all annotation files and move them to the new partition
# directories
if isinstance(annotations, CARDAnnotationDirectoryFormat):
for sample_id, mag_id, annotation_fp in annotation_tuple:
annotation_des_fp = os.path.join(
for sample_mag_id, file_path_list in annotations_split.items():
for file_path in file_path_list:
file_path_des = os.path.join(
partitioned_annotation.path,
sample_id,
mag_id,
os.path.basename(annotation_fp),
sample_mag_id,
os.path.basename(file_path),
)
os.makedirs(os.path.dirname(annotation_des_fp), exist_ok=True)
duplicate(annotation_fp, annotation_des_fp)

partitioned_annotation_key = mag_id
os.makedirs(os.path.dirname(file_path_des), exist_ok=True)
duplicate(file_path, file_path_des)

if isinstance(annotations, CARDAnnotationDirectoryFormat):
partitioned_annotation_key = sample_mag_id.replace("/", "_")
else:
for sample_id, annotation_fp in annotation_tuple:
annotation_des_fp = os.path.join(
partitioned_annotation.path,
sample_id,
os.path.basename(annotation_fp),
)
os.makedirs(os.path.dirname(annotation_des_fp), exist_ok=True)
duplicate(annotation_fp, annotation_des_fp)

partitioned_annotation_key = sample_id
partitioned_annotation_key = sample_mag_id

# Add the partitioned object to the collection
if num_partitions == num_annotations: # and not duplicates:
Expand Down
70 changes: 37 additions & 33 deletions q2_amr/card/tests/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,18 @@ def test_partition_mags_annotations(self):
# Run partition_mags_annotations
obs = partition_mags_annotations(annotations=annotations, num_partitions=2)

mag_ids = [
"f5a16381-ea80-49f9-875e-620f333a9293",
"e026af61-d911-4de3-a957-7e8bf837f30d",
collection_keys = [
"sample2_f5a16381-ea80-49f9-875e-620f333a9293",
"sample1_e026af61-d911-4de3-a957-7e8bf837f30d",
]

# Assert if keys of collection are correct
self.assertTrue(set(obs.keys()) == set(mag_ids))
self.assertTrue(set(obs.keys()) == set(collection_keys))

# Assert if all files exist in the correct locations
for mag_id, samp in zip(mag_ids, ["sample2", "sample1"]):
for mag_id, samp in zip(collection_keys, ["sample2", "sample1"]):
for file in ["amr_annotation.json", "amr_annotation.txt"]:
path = os.path.join(obs[mag_id].path, samp, mag_id, file)
path = os.path.join(obs[mag_id].path, samp, mag_id[8:], file)
self.assertTrue(os.path.exists(path))

def test_partition_mags_warning_message(self):
Expand All @@ -197,39 +197,43 @@ def test_partition_mags_warning_message(self):
partition_mags_annotations(annotations=annotations, num_partitions=5)

def test_partition_reads_allele_annotations(self):
self._test_partition_reads_annotations(
dir="collated/annotate_reads_allele_output",
files=[
"allele_mapping_data.txt",
"overall_mapping_stats.txt",
"sorted.length_100.bam",
],
format=CARDAlleleAnnotationDirectoryFormat,
function=partition_reads_allele_annotations,
# Set up for annotations
annotations = self.setup_annotations(
"collated/annotate_reads_allele_output", CARDAlleleAnnotationDirectoryFormat
)
# Run function
obs = partition_reads_allele_annotations(annotations=annotations)

def test_partition_reads_gene_annotations(self):
self._test_partition_reads_annotations(
dir="collated/annotate_reads_gene_output",
files=["gene_mapping_data.txt"],
format=CARDGeneAnnotationDirectoryFormat,
function=partition_reads_gene_annotations,
)
# Assert if keys of collection are correct
self.assertTrue(set(obs.keys()) == {"sample2", "sample1"})

def _test_partition_reads_annotations(self, dir, files, format, function):
# Set up for annotations
annotations = self.setup_annotations(dir, format)
# Assert if all files exist in the right location
file_paths = [
os.path.join(obs["sample2"].path, "sample2", "allele_mapping_data.txt"),
os.path.join(obs["sample2"].path, "sample2", "overall_mapping_stats.txt"),
os.path.join(obs["sample2"].path, "sample2", "sorted.length_100.bam"),
os.path.join(obs["sample1"].path, "sample1", "allele_mapping_data.txt"),
os.path.join(obs["sample1"].path, "sample1", "overall_mapping_stats.txt"),
os.path.join(obs["sample1"].path, "sample1", "sorted.length_100.bam"),
]
for file_path in file_paths:
self.assertTrue(os.path.exists(file_path))

def test_partition_reads_gene_annotations(self):
# Set up for annotations
annotations = self.setup_annotations(
"collated/annotate_reads_gene_output", CARDGeneAnnotationDirectoryFormat
)
# Run function
obs = function(annotations=annotations)

samples = ["sample2", "sample1"]
obs = partition_reads_gene_annotations(annotations=annotations)

# Assert if keys of collection are correct
self.assertTrue(set(obs.keys()) == set(samples))
self.assertTrue(set(obs.keys()) == {"sample2", "sample1"})

# Assert if all files exist in the right location
for key, sample in zip(list(obs.keys()), samples):
for file in files:
path = os.path.join(obs[key].path, sample, file)
self.assertTrue(os.path.exists(path))
file_paths = [
os.path.join(obs["sample2"].path, "sample2", "gene_mapping_data.txt"),
os.path.join(obs["sample1"].path, "sample1", "gene_mapping_data.txt"),
]
for file_path in file_paths:
self.assertTrue(os.path.exists(file_path))

0 comments on commit e78db60

Please sign in to comment.