From e78db60431f2d22e29b63d807da5e80ac83e7a4b Mon Sep 17 00:00:00 2001 From: VinzentRisch Date: Thu, 11 Apr 2024 16:31:39 +0200 Subject: [PATCH] changed partition to work with new sample_dict function changed tests for more readability --- q2_amr/card/partition.py | 60 +++++++++---------------- q2_amr/card/tests/test_partition.py | 70 +++++++++++++++-------------- 2 files changed, 57 insertions(+), 73 deletions(-) diff --git a/q2_amr/card/partition.py b/q2_amr/card/partition.py index afed9cb..c801578 100644 --- a/q2_amr/card/partition.py +++ b/q2_amr/card/partition.py @@ -1,8 +1,8 @@ +import itertools import os import warnings from typing import Union -import numpy as np from qiime2.util import duplicate from q2_amr.types import ( @@ -130,23 +130,10 @@ def _partition_annotations( num_partitions: int = None, ): partitioned_annotations = {} - annotations_all = [] - # Add one tuple with sample id MAG id and full paths to annotation files to - # annotations_all per annotation file - if isinstance(annotations, CARDAnnotationDirectoryFormat): - for sample_id, mag in annotations.sample_dict().items(): - for mag_id, annotation_fp_list in mag.items(): - for annotation_fp in annotation_fp_list: - annotations_all.append((sample_id, mag_id, annotation_fp)) - - else: - for sample_id, annotation_fp_list in annotations.sample_dict().items(): - for annotation_fp in annotation_fp_list: - annotations_all.append((sample_id, annotation_fp)) - - # Retrieve the number of annotations - num_annotations = len({tup[-2] for tup in annotations_all}) + # Get dict with paths to all files in artifact and get number of annotations + annotations_dict = annotations.sample_dict() + num_annotations = len(annotations_dict) # If no number of partitions is specified or the number is higher than the number # of annotations, all annotations get partitioned by annotation @@ -162,39 +149,32 @@ def _partition_annotations( ) num_partitions = num_annotations - # Splits annotations into the specified number of partitions - arrays = np.array_split(annotations_all, num_partitions) + # Split the dict into a list of specified number of dicts + i = itertools.cycle(range(num_partitions)) + annotations_split_list = [{} for _ in range(num_partitions)] + for key, value in annotations_dict.items(): + annotations_split_list[next(i)][key] = value - for i, annotation_tuple in enumerate(arrays, 1): + for i, annotations_split in enumerate(annotations_split_list, 1): # Creates directory with same format as input partitioned_annotation = type(annotations)() # Constructs paths to all annotation files and move them to the new partition # directories - if isinstance(annotations, CARDAnnotationDirectoryFormat): - for sample_id, mag_id, annotation_fp in annotation_tuple: - annotation_des_fp = os.path.join( + for sample_mag_id, file_path_list in annotations_split.items(): + for file_path in file_path_list: + file_path_des = os.path.join( partitioned_annotation.path, - sample_id, - mag_id, - os.path.basename(annotation_fp), + sample_mag_id, + os.path.basename(file_path), ) - os.makedirs(os.path.dirname(annotation_des_fp), exist_ok=True) - duplicate(annotation_fp, annotation_des_fp) - - partitioned_annotation_key = mag_id + os.makedirs(os.path.dirname(file_path_des), exist_ok=True) + duplicate(file_path, file_path_des) + if isinstance(annotations, CARDAnnotationDirectoryFormat): + partitioned_annotation_key = sample_mag_id.replace("/", "_") else: - for sample_id, annotation_fp in annotation_tuple: - annotation_des_fp = os.path.join( - partitioned_annotation.path, - sample_id, - os.path.basename(annotation_fp), - ) - os.makedirs(os.path.dirname(annotation_des_fp), exist_ok=True) - duplicate(annotation_fp, annotation_des_fp) - - partitioned_annotation_key = sample_id + partitioned_annotation_key = sample_mag_id # Add the partitioned object to the collection if num_partitions == num_annotations: # and not duplicates: diff --git a/q2_amr/card/tests/test_partition.py b/q2_amr/card/tests/test_partition.py index 3a4b4dd..6c23196 100644 --- a/q2_amr/card/tests/test_partition.py +++ b/q2_amr/card/tests/test_partition.py @@ -170,18 +170,18 @@ def test_partition_mags_annotations(self): # Run partition_mags_annotations obs = partition_mags_annotations(annotations=annotations, num_partitions=2) - mag_ids = [ - "f5a16381-ea80-49f9-875e-620f333a9293", - "e026af61-d911-4de3-a957-7e8bf837f30d", + collection_keys = [ + "sample2_f5a16381-ea80-49f9-875e-620f333a9293", + "sample1_e026af61-d911-4de3-a957-7e8bf837f30d", ] # Assert if keys of collection are correct - self.assertTrue(set(obs.keys()) == set(mag_ids)) + self.assertTrue(set(obs.keys()) == set(collection_keys)) # Assert if all files exist in the correct locations - for mag_id, samp in zip(mag_ids, ["sample2", "sample1"]): + for mag_id, samp in zip(collection_keys, ["sample2", "sample1"]): for file in ["amr_annotation.json", "amr_annotation.txt"]: - path = os.path.join(obs[mag_id].path, samp, mag_id, file) + path = os.path.join(obs[mag_id].path, samp, mag_id[8:], file) self.assertTrue(os.path.exists(path)) def test_partition_mags_warning_message(self): @@ -197,39 +197,43 @@ def test_partition_mags_warning_message(self): partition_mags_annotations(annotations=annotations, num_partitions=5) def test_partition_reads_allele_annotations(self): - self._test_partition_reads_annotations( - dir="collated/annotate_reads_allele_output", - files=[ - "allele_mapping_data.txt", - "overall_mapping_stats.txt", - "sorted.length_100.bam", - ], - format=CARDAlleleAnnotationDirectoryFormat, - function=partition_reads_allele_annotations, + # Set up for annotations + annotations = self.setup_annotations( + "collated/annotate_reads_allele_output", CARDAlleleAnnotationDirectoryFormat ) + # Run function + obs = partition_reads_allele_annotations(annotations=annotations) - def test_partition_reads_gene_annotations(self): - self._test_partition_reads_annotations( - dir="collated/annotate_reads_gene_output", - files=["gene_mapping_data.txt"], - format=CARDGeneAnnotationDirectoryFormat, - function=partition_reads_gene_annotations, - ) + # Assert if keys of collection are correct + self.assertTrue(set(obs.keys()) == {"sample2", "sample1"}) - def _test_partition_reads_annotations(self, dir, files, format, function): - # Set up for annotations - annotations = self.setup_annotations(dir, format) + # Assert if all files exist in the right location + file_paths = [ + os.path.join(obs["sample2"].path, "sample2", "allele_mapping_data.txt"), + os.path.join(obs["sample2"].path, "sample2", "overall_mapping_stats.txt"), + os.path.join(obs["sample2"].path, "sample2", "sorted.length_100.bam"), + os.path.join(obs["sample1"].path, "sample1", "allele_mapping_data.txt"), + os.path.join(obs["sample1"].path, "sample1", "overall_mapping_stats.txt"), + os.path.join(obs["sample1"].path, "sample1", "sorted.length_100.bam"), + ] + for file_path in file_paths: + self.assertTrue(os.path.exists(file_path)) + def test_partition_reads_gene_annotations(self): + # Set up for annotations + annotations = self.setup_annotations( + "collated/annotate_reads_gene_output", CARDGeneAnnotationDirectoryFormat + ) # Run function - obs = function(annotations=annotations) - - samples = ["sample2", "sample1"] + obs = partition_reads_gene_annotations(annotations=annotations) # Assert if keys of collection are correct - self.assertTrue(set(obs.keys()) == set(samples)) + self.assertTrue(set(obs.keys()) == {"sample2", "sample1"}) # Assert if all files exist in the right location - for key, sample in zip(list(obs.keys()), samples): - for file in files: - path = os.path.join(obs[key].path, sample, file) - self.assertTrue(os.path.exists(path)) + file_paths = [ + os.path.join(obs["sample2"].path, "sample2", "gene_mapping_data.txt"), + os.path.join(obs["sample1"].path, "sample1", "gene_mapping_data.txt"), + ] + for file_path in file_paths: + self.assertTrue(os.path.exists(file_path))