diff --git a/rescript/ncbi_datasets.py b/rescript/ncbi_datasets.py index 4458eeb..052acc0 100644 --- a/rescript/ncbi_datasets.py +++ b/rescript/ncbi_datasets.py @@ -156,12 +156,14 @@ def _fetch_and_extract_dataset( def _fetch_taxonomy( all_acc_ids: list, all_tax_ids: list, - accession_to_assembly: pd.Series + accession_to_assembly: pd.Series, + ranks: list, + rank_propagation: bool ): manager = Manager() taxa, bad_accs = get_taxonomies( taxids={k: v for k, v in zip(all_acc_ids, all_tax_ids)}, - ranks=_default_ranks, rank_propagation=True, + ranks=ranks, rank_propagation=rank_propagation, logging_level='INFO', n_jobs=2, request_lock=manager.Lock() ) # technically, this should never happen as the taxa accession @@ -213,6 +215,8 @@ def get_ncbi_genomes( only_genomic: bool = False, tax_exact_match: bool = False, page_size: int = 20, + ranks: list = _default_ranks, + rank_propagation: bool = True, ) -> (DNAFASTAFormat, LociDirectoryFormat, ProteinsDirectoryFormat, pd.DataFrame): # we use a deepcopy of assembly_levels because the new versions of @@ -239,7 +243,9 @@ def get_ncbi_genomes( taxa = _fetch_taxonomy( assembly_to_taxon.keys(), assembly_to_taxon.values(), - accession_map.explode() + accession_map.explode(), + ranks, + rank_propagation ) return genomes, loci, proteins, taxa diff --git a/rescript/plugin_setup.py b/rescript/plugin_setup.py index 376f71f..a4d5247 100644 --- a/rescript/plugin_setup.py +++ b/rescript/plugin_setup.py @@ -1190,7 +1190,9 @@ 'assembly_levels': List[Str % Choices( ['complete_genome', 'chromosome', 'scaffold', 'contig'])], 'tax_exact_match': Bool, - 'page_size': Int % Range(20, 1000, inclusive_end=True) + 'page_size': Int % Range(20, 1000, inclusive_end=True), + 'ranks': List[Str % Choices(_allowed_ranks)], + 'rank_propagation': Bool, }, outputs=[ ('genome_assemblies', FeatureData[Sequence]), @@ -1217,6 +1219,10 @@ 'request. If number of genomes to fetch is higher than ' 'this number, requests will be repeated until all ' 'assemblies are fetched.', + 'ranks': 'List of taxonomic ranks for building a taxonomy from the ' + 'NCBI Taxonomy database.', + 'rank_propagation': RANK_PROPAGATE_DESCRIPTION, + }, output_descriptions={ 'genome_assemblies': 'Nucleotide sequences of requested genomes.', diff --git a/rescript/tests/test_ncbi_datasets.py b/rescript/tests/test_ncbi_datasets.py index 2f2f306..14151c9 100644 --- a/rescript/tests/test_ncbi_datasets.py +++ b/rescript/tests/test_ncbi_datasets.py @@ -198,7 +198,9 @@ def test_fetch_taxonomy(self, p): pd.Series( {'GCF_123': ['AC_12.1'], 'GCF_234': ['AC_23.2']}, name="assembly_id" - ).explode() + ).explode(), + _default_ranks, + True ) exp_taxa = pd.DataFrame( @@ -221,7 +223,11 @@ def test_fetch_taxonomy_bad_accs(self, p): Exception, r'Invalid taxonomy.*\: ACC1, ACC2. Please check.*' ): _fetch_taxonomy( - self.fake_assembly_ids, self.fake_tax_ids, pd.Series() + self.fake_assembly_ids, + self.fake_tax_ids, + pd.Series(), + _default_ranks, + True ) # just test that everything works together