From cb3a36f373688293c36cf88a3c49ce7ccbded1fe Mon Sep 17 00:00:00 2001 From: Robert Petryszak Date: Thu, 19 Oct 2023 14:44:18 +0100 Subject: [PATCH] Incorporating linter suggestions (WIP) --- cellphonedb/utils/db_releases_utils.py | 31 ++-- cellphonedb/utils/db_utils.py | 243 +++++++++++++++++-------- cellphonedb/utils/file_utils.py | 23 ++- cellphonedb/utils/search_utils.py | 105 ++++++----- 4 files changed, 261 insertions(+), 141 deletions(-) diff --git a/cellphonedb/utils/db_releases_utils.py b/cellphonedb/utils/db_releases_utils.py index c43c72bd..ae8d0332 100644 --- a/cellphonedb/utils/db_releases_utils.py +++ b/cellphonedb/utils/db_releases_utils.py @@ -43,17 +43,7 @@ def get_remote_database_versions_html(include_file_browsing: bool = False, min_v rel_version = float('.'.join(rel_tag.replace('v', '').split(".")[0:2])) if rel_version < min_version: continue - html += "{}" \ - .format(css_style, rel['html_url'], rel['tag_name']) - html += "{}".format(css_style, rel['published_at'].split("T")[0]) - if include_file_browsing: - html += ("" + - "pageview").format(css_style, rel['tag_name']) - html += "" + html += get_release_info(rel, css_style, include_file_browsing) html += "" html += "" result['db_releases_html_table'] = html @@ -67,6 +57,25 @@ def get_remote_database_versions_html(include_file_browsing: bool = False, min_v return result +def get_release_info( + rel: dict, + css_style: str, + include_file_browsing: bool +) -> str: + html = "{}" \ + .format(css_style, rel['html_url'], rel['tag_name']) + html += "{}".format(css_style, rel['published_at'].split("T")[0]) + if include_file_browsing: + html += ("" + + "pageview").format(css_style, rel['tag_name']) + html += "" + return html + + def _github_query(kind) -> Union[dict, list]: queries = { 'releases': 'https://api.github.com/repos/{}/{}/releases'.format('ventolab', 'cellphonedb-data'), diff --git a/cellphonedb/utils/db_utils.py b/cellphonedb/utils/db_utils.py index 2a17d802..06f48f2d 100644 --- a/cellphonedb/utils/db_utils.py +++ b/cellphonedb/utils/db_utils.py @@ -224,6 +224,59 @@ def get_column_names_for_db_version(complex_db_df, interactions_df, protein_df) return (protein_column_names, interaction_column_names1, interaction_column_names2, version, complex_columns) +def collect_protein_data(data_dfs: dict) -> pd.DataFrame: + # Collect protein data + protein_db_df = data_dfs['protein_input'][['protein_name', 'tags', 'tags_reason', 'tags_description', 'uniprot']] + num_proteins = protein_db_df.shape[0] + multidata_id_list_so_far = list(range(num_proteins)) + protein_db_df.insert(0, 'id_protein', multidata_id_list_so_far, False) + protein_db_df.insert(len(protein_db_df.columns), 'protein_multidata_id', protein_db_df['id_protein'].tolist(), True) + # dbg(protein_db_df.info) + return protein_db_df, multidata_id_list_so_far + + +def collect_gene_data(data_dfs: dict, protein_db_df: pd.DataFrame) -> pd.DataFrame: + gene_db_df = data_dfs['gene_input'][['ensembl', 'gene_name', 'hgnc_symbol', 'uniprot']] + num_genes = gene_db_df.shape[0] + gene_db_df.insert(0, 'id_gene', list(range(num_genes)), False) + # Assign values from protein_db_df['protein_multidata_id'] into gene_db_df['protein_id'] + # via join between 'uniprot' and 'protein_name' + gene_db_df = pd.merge(gene_db_df, protein_db_df[['protein_name', 'protein_multidata_id', 'uniprot']], on='uniprot') + gene_db_df = gene_db_df.drop('uniprot', axis=1) + protein_db_df = protein_db_df.drop('uniprot', axis=1) + gene_db_df.rename(columns={'protein_multidata_id': 'protein_id'}, inplace=True) + # print(gene_db_df.info) + return gene_db_df + + +def collect_receptor_to_tf_mapping(data_dfs: dict) -> pd.DataFrame: + receptor_to_tf_df = None + # Cater for DB version-dependent input files + if data_dfs['transcription_factor_input'] is not None: + receptor_to_tf_df = data_dfs['transcription_factor_input'][['receptor_id', 'TF_symbol']].copy() + receptor_to_tf_df.columns = ['Receptor', 'TF'] + # Strip any leading or trailing spaces + receptor_to_tf_df.replace(r'\s*(.*?)\s*', r'\1', regex=True, inplace=True) + return receptor_to_tf_df + + +def collect_gene_synonym_to_gene_name_mapping(data_dfs: dict, gene_db_df: pd.DataFrame) -> pd.DataFrame: + gene_synonym_to_gene_name_db_df = None + # Cater for DB version-dependent input files + if data_dfs['gene_synonyms_input'] is not None: + gene_synonym_to_gene_name = {} + for gene_names in data_dfs['gene_synonyms_input']\ + .filter(regex=("Gene Names.*")).dropna().agg(' '.join, axis=1).tolist(): + gene_names_arr = re.split(';\\s*|\\s+', gene_names) + for gene_name in gene_db_df[gene_db_df['gene_name'].isin(gene_names_arr)]['gene_name'].tolist(): + for gene_synonym in gene_names_arr: + if gene_synonym != gene_name: + gene_synonym_to_gene_name[gene_synonym] = gene_name + gene_synonym_to_gene_name_db_df = pd.DataFrame(gene_synonym_to_gene_name.items(), + columns=['Gene Synonym', 'Gene Name']) + return gene_synonym_to_gene_name_db_df + + def create_db(target_dir) -> None: """ Creates CellphoneDB databases file (cellphonedb.zip) in directory. @@ -250,64 +303,32 @@ def create_db(target_dir) -> None: gene_synonyms_input = os.path.join(target_dir, "sources/uniprot_synonyms.tsv") pathlib.Path(target_dir).mkdir(parents=True, exist_ok=True) - dataDFs = getDFs(gene_input=gene_input, protein_input=protein_input, complex_input=complex_input, - interaction_input=interaction_input, transcription_factor_input=transcription_factor_input, - gene_synonyms_input=gene_synonyms_input) + data_dfs = get_dfs(gene_input=gene_input, protein_input=protein_input, complex_input=complex_input, + interaction_input=interaction_input, transcription_factor_input=transcription_factor_input, + gene_synonyms_input=gene_synonyms_input) (protein_column_names, interaction_column_names1, interaction_column_names2, version, complex_columns) = \ get_column_names_for_db_version( - dataDFs['complex_input'], dataDFs['interaction_input'], dataDFs['protein_input']) + data_dfs['complex_input'], data_dfs['interaction_input'], data_dfs['protein_input']) # Perform sanity tests on *_input files and report any issues to the user as warnings - run_sanity_tests(dataDFs, protein_column_names, version) + run_sanity_tests(data_dfs, protein_column_names, version) # Collect protein data - protein_db_df = dataDFs['protein_input'][['protein_name', 'tags', 'tags_reason', 'tags_description', 'uniprot']] - num_proteins = protein_db_df.shape[0] - multidata_id_list_so_far = list(range(num_proteins)) - protein_db_df.insert(0, 'id_protein', multidata_id_list_so_far, False) - protein_db_df.insert(len(protein_db_df.columns), 'protein_multidata_id', protein_db_df['id_protein'].tolist(), True) - # dbg(protein_db_df.info) + protein_db_df, multidata_id_list_so_far = collect_protein_data(data_dfs) # Collect gene data - gene_db_df = dataDFs['gene_input'][['ensembl', 'gene_name', 'hgnc_symbol', 'uniprot']] - num_genes = gene_db_df.shape[0] - gene_db_df.insert(0, 'id_gene', list(range(num_genes)), False) - # Assign values from protein_db_df['protein_multidata_id'] into gene_db_df['protein_id'] - # via join between 'uniprot' and 'protein_name' - gene_db_df = pd.merge(gene_db_df, protein_db_df[['protein_name', 'protein_multidata_id', 'uniprot']], on='uniprot') - gene_db_df = gene_db_df.drop('uniprot', axis=1) - protein_db_df = protein_db_df.drop('uniprot', axis=1) - gene_db_df.rename(columns={'protein_multidata_id': 'protein_id'}, inplace=True) - # print(gene_db_df.info) + gene_db_df = collect_gene_data(data_dfs, protein_db_df) # Collect mapping: (receptor) gene name -> TF gene name (in transcription_factor_input.tsv) - receptor_to_tf_df = None - # Cater for DB version-dependent input files - if dataDFs['transcription_factor_input'] is not None: - receptor_to_tf_df = dataDFs['transcription_factor_input'][['receptor_id', 'TF_symbol']].copy() - receptor_to_tf_df.columns = ['Receptor', 'TF'] - # Strip any leading or trailing spaces - receptor_to_tf_df.replace(r'\s*(.*?)\s*', r'\1', regex=True, inplace=True) + receptor_to_tf_df = collect_receptor_to_tf_mapping(data_dfs) # Collect mapping: gene synonym (not in gene_input.csv) -> gene name (in gene_input.csv) - gene_synonym_to_gene_name_db_df = None - # Cater for DB version-dependent input files - if dataDFs['gene_synonyms_input'] is not None: - gene_synonym_to_gene_name = {} - for gene_names in dataDFs['gene_synonyms_input']\ - .filter(regex=("Gene Names.*")).dropna().agg(' '.join, axis=1).tolist(): - gene_names_arr = re.split(';\\s*|\\s+', gene_names) - for gene_name in gene_db_df[gene_db_df['gene_name'].isin(gene_names_arr)]['gene_name'].tolist(): - for gene_synonym in gene_names_arr: - if gene_synonym != gene_name: - gene_synonym_to_gene_name[gene_synonym] = gene_name - gene_synonym_to_gene_name_db_df = pd.DataFrame(gene_synonym_to_gene_name.items(), - columns=['Gene Synonym', 'Gene Name']) + gene_synonym_to_gene_name_db_df = collect_gene_synonym_to_gene_name_mapping(data_dfs, gene_db_df) # Collect complex data cols = protein_column_names + ['pdb_structure', 'pdb_id', 'stoichiometry', 'comments_complex'] + complex_columns - complex_db_df = dataDFs['complex_input'][cols] + complex_db_df = data_dfs['complex_input'][cols] # Note that uniprot_* cols will be dropped after complex_composition_df has been constructed num_complexes = complex_db_df.shape[0] @@ -322,11 +343,11 @@ def create_db(target_dir) -> None: # Collect multidata # Insert proteins into multidata multidata_db_df = \ - dataDFs['protein_input'][['uniprot', 'receptor', 'receptor_desc', 'other', 'other_desc', 'secreted_highlight', + data_dfs['protein_input'][['uniprot', 'receptor', 'receptor_desc', 'other', 'other_desc', 'secreted_highlight', 'secreted_desc', 'transmembrane', 'secreted', 'peripheral', 'integrin']].copy() multidata_db_df.rename(columns={'uniprot': 'name'}, inplace=True) multidata_ids = pd.merge( - dataDFs['protein_input'][['protein_name']], + data_dfs['protein_input'][['protein_name']], protein_db_df[['protein_name', 'protein_multidata_id']], on='protein_name')['protein_multidata_id'].tolist() multidata_db_df.insert(0, 'id_multidata', multidata_ids, False) multidata_db_df.insert(len(multidata_db_df.columns), 'is_complex', @@ -335,7 +356,7 @@ def create_db(target_dir) -> None: # Insert complexes into multidata cols = ['complex_name', 'receptor', 'receptor_desc', 'other', 'other_desc', 'secreted_highlight', 'secreted_desc', 'transmembrane', 'secreted', 'peripheral', 'integrin'] - complex_aux_df = dataDFs['complex_input'][cols].copy() + complex_aux_df = data_dfs['complex_input'][cols].copy() complex_aux_df.rename(columns={'complex_name': 'name'}, inplace=True) complex_aux_df.insert(0, 'id_multidata', complex_multidata_ids, False) complex_aux_df.insert(len(complex_aux_df.columns), 'is_complex', @@ -371,7 +392,7 @@ def create_db(target_dir) -> None: complex_db_df = complex_db_df.drop(col, axis=1) # Collect interaction data - interactions_aux_df = pd.merge(dataDFs['interaction_input'], multidata_db_df, + interactions_aux_df = pd.merge(data_dfs['interaction_input'], multidata_db_df, left_on=['partner_a'], right_on=['name']) interactions_aux_df = pd.merge(interactions_aux_df, multidata_db_df, left_on=['partner_b'], right_on=['name'], suffixes=['_x', '_y']) @@ -436,8 +457,8 @@ def download_released_files(target_dir, cpdb_version, regex): print("Downloaded {} into {}".format(fname, target_dir)) -def getDFs(gene_input=None, protein_input=None, complex_input=None, interaction_input=None, - transcription_factor_input=None, gene_synonyms_input=None): +def get_dfs(gene_input=None, protein_input=None, complex_input=None, interaction_input=None, + transcription_factor_input=None, gene_synonyms_input=None): dfs = {} dfs['gene_input'] = file_utils.read_data_table_from_file(gene_input) dfs['protein_input'] = file_utils.read_data_table_from_file(protein_input) @@ -448,18 +469,7 @@ def getDFs(gene_input=None, protein_input=None, complex_input=None, interaction_ return dfs -def run_sanity_tests(dataDFs, protein_column_names, version): - data_errors_found = False - protein_db_df = dataDFs['protein_input'] - complex_db_df = dataDFs['complex_input'] - gene_db_df = dataDFs['gene_input'] - interaction_db_df = dataDFs['interaction_input'] - tf_input_df = None - # Cater for DB version-dependent input files - if 'transcription_factor_input' in dataDFs: - tf_input_df = dataDFs['transcription_factor_input'] - - # 1. Report any uniprot accessions that map to multiple gene_names +def sanity_test_uniprot_accessions_map_to_multiple_gene_names(gene_db_df: pd.DataFrame): gene_names_uniprot_df = gene_db_df[['gene_name', 'uniprot']].copy() gene_names_uniprot_df.drop_duplicates(inplace=True) dups = gene_names_uniprot_df[gene_names_uniprot_df['uniprot'].duplicated() == True] @@ -469,17 +479,16 @@ def run_sanity_tests(dataDFs, protein_column_names, version): "they should map to only one):") print(", ".join(dups['uniprot'].tolist())) - # 2. Warn about complex name duplicates in complex_db_df + +def sanity_test_report_complex_duplicates(complex_db_df: pd.DataFrame): test_complex_db_df = complex_db_df.set_index('complex_name') if test_complex_db_df.index.has_duplicates: print("WARNING: complex_input.csv has the following duplicates:") print("\n".join(complex_db_df[test_complex_db_df.index.duplicated(keep='first')]['complex_name'].tolist()) + "\n") - # 3. Report complexes with (possibly) different names, but with the same uniprot - # accession participants (though not necessarily in the same order - hence the use of set below) - # NB. Use set below as we don't care about the order of participants when looking for duplicates - # NB. Report duplicate complexes _only if_ at least one duplicate's complex_db_df['version'] - # does not start with CORE_CELLPHONEDB_DATA) + +def sanity_test_report_complexes_with_same_participants( + complex_db_df: pd.DataFrame, version: str, protein_column_names: list): participants_set_to_complex_names = {} participants_set_to_data_sources = {} cols = ['complex_name'] + version + protein_column_names @@ -511,8 +520,8 @@ def run_sanity_tests(dataDFs, protein_column_names, version): print("WARNING: The following multiple complexes (left) appear to have the same composition (right):") print(complex_dups) - # 4. Report interactions with (possibly) a different name, but with the same participants - # (though not necessarily in the same order - hence the use of set below) + +def sanity_test_report_interactions_with_same_participants(interaction_db_df: pd.DataFrame): partner_sets = [set([i for i in row]) for row in interaction_db_df[['partner_a', 'partner_b']].itertuples(index=False)] # Find duplicate sets of partners @@ -525,13 +534,15 @@ def run_sanity_tests(dataDFs, protein_column_names, version): print(','.join(dup)) print() - # 5. Warn about uniprot accession duplicates in protein_db_df + +def sanity_test_report_uniprot_accession_duplicates(protein_db_df: pd.DataFrame): test_protein_db_df = protein_db_df.set_index('uniprot') if test_protein_db_df.index.has_duplicates: print("WARNING: protein_input.csv has the following UniProt accession duplicates:") print("\n".join(protein_db_df[test_protein_db_df.index.duplicated(keep='first')]['uniprot'].tolist()) + "\n") - # 6. Warn the user if some complexes don't participate in any interactions + +def sanity_test_report_orphan_complexes(complex_db_df: pd.DataFrame, interaction_db_df: pd.DataFrame) -> set: all_complexes_set = set(complex_db_df['complex_name'].tolist()) interaction_participants_set = set(interaction_db_df['partner_a'].tolist() + interaction_db_df['partner_b'].tolist()) orphan_complexes = all_complexes_set - interaction_participants_set @@ -539,10 +550,18 @@ def run_sanity_tests(dataDFs, protein_column_names, version): print("WARNING: The following complexes are not found in interaction_input.txt:") print("\n".join(orphan_complexes)) print() + return orphan_complexes - # 7. Warn the user if some proteins don't participate in any interactions directly, - # or are part some complex in orphan_complexes + +def sanity_test_report_orphan_proteins( + protein_db_df: pd.DataFrame, + complex_db_df: pd.DataFrame, + interaction_db_df: pd.DataFrame, + protein_column_names: list, + orphan_complexes: set): all_proteins_set = set(protein_db_df['uniprot'].tolist()) + interaction_participants_set = set( + interaction_db_df['partner_a'].tolist() + interaction_db_df['partner_b'].tolist()) proteins_in_complexes_participating_in_interactions = [] for colName in protein_column_names: proteins_in_complexes_participating_in_interactions += \ @@ -554,21 +573,30 @@ def run_sanity_tests(dataDFs, protein_column_names, version): "or via complexes they are part of):") print("\n".join(orphan_proteins)) - # 8. Warn the user if some interactions contain interactors that are neither - # in complex_input.csv or protein_input.csv + +def sanity_test_report_unknown_interactors( + protein_db_df: pd.DataFrame, + complex_db_df: pd.DataFrame, + interaction_db_df: pd.DataFrame +): unknown_interactors = set() for col in ['partner_a', 'partner_b']: aux_df = pd.merge(interaction_db_df, protein_db_df, left_on=col, right_on='uniprot', how='outer') unknown_interactor_proteins = set(aux_df[pd.isnull(aux_df['uniprot'])][col].tolist()) aux_df = pd.merge(interaction_db_df, complex_db_df, left_on=col, right_on='complex_name', how='outer') unknown_interactor_complexes = set(aux_df[pd.isnull(aux_df['complex_name'])][col].tolist()) - unknown_interactors = unknown_interactors.union(unknown_interactor_proteins.intersection(unknown_interactor_complexes)) + unknown_interactors = unknown_interactors.union( + unknown_interactor_proteins.intersection(unknown_interactor_complexes)) if unknown_interactors: print("WARNING: The following interactors in interaction_input.txt could not be found in either " + "protein_input.csv or complex_indput.csv:") print("\n".join(sorted(unknown_interactors)) + "\n") - # 9. Warn if some complexes contain proteins not in protein_input.csv + +def sanity_test_report_unknown_proteins( + protein_db_df: pd.DataFrame, + complex_db_df: pd.DataFrame, + protein_column_names: list): unknown_proteins = set() for col in protein_column_names: aux_df = pd.merge(complex_db_df, protein_db_df, left_on=col, right_on='uniprot', how='outer') @@ -578,7 +606,10 @@ def run_sanity_tests(dataDFs, protein_column_names, version): print("WARNING: The following proteins in complex_input.txt could not be found in protein_input.csv:") print("\n".join(sorted(unknown_proteins)) + "\n") - # 10. Warn if some proteins in protein_input.csv are not in gene_input.csv + +def sanity_test_report_proteins_not_in_genes_file( + protein_db_df: pd.DataFrame, + gene_db_df: pd.DataFrame): proteins = set(protein_db_df['uniprot'].tolist()) unknown_proteins = proteins.difference(set(gene_db_df['uniprot'].tolist())) if unknown_proteins: @@ -586,7 +617,11 @@ def run_sanity_tests(dataDFs, protein_column_names, version): print("\n".join(sorted(unknown_proteins)) + "\n") print() - # 11. Warn if some receptor ids in tf_input_df are in neither gene_input.csv or complex_input.csv + +def sanity_test_report_tfs_not_in_gene_or_complex_files( + gene_db_df: pd.DataFrame, + complex_db_df: pd.DataFrame, + tf_input_df: pd.DataFrame): if tf_input_df is not None: # Cater for DB version-dependent input files for (bioentity, df) in {"gene": gene_db_df, "complex": complex_db_df}.items(): @@ -604,5 +639,57 @@ def run_sanity_tests(dataDFs, protein_column_names, version): print("\n".join(set(bioentities_not_in_input))) print() + +def run_sanity_tests(data_dfs, protein_column_names, version): + data_errors_found = False + protein_db_df = data_dfs['protein_input'] + complex_db_df = data_dfs['complex_input'] + gene_db_df = data_dfs['gene_input'] + interaction_db_df = data_dfs['interaction_input'] + tf_input_df = None + # Cater for DB version-dependent input files + if 'transcription_factor_input' in data_dfs: + tf_input_df = data_dfs['transcription_factor_input'] + + # 1. Report any uniprot accessions that map to multiple gene_names + sanity_test_uniprot_accessions_map_to_multiple_gene_names(gene_db_df) + + # 2. Warn about complex name duplicates in complex_db_df + sanity_test_report_complex_duplicates(complex_db_df) + + # 3. Report complexes with (possibly) different names, but with the same uniprot + # accession participants (though not necessarily in the same order - hence the use of set below) + # NB. Use set below as we don't care about the order of participants when looking for duplicates + # NB. Report duplicate complexes _only if_ at least one duplicate's complex_db_df['version'] + # does not start with CORE_CELLPHONEDB_DATA) + sanity_test_report_complexes_with_same_participants(complex_db_df, version, protein_column_names) + + # 4. Report interactions with (possibly) a different name, but with the same participants + # (though not necessarily in the same order - hence the use of set below) + sanity_test_report_interactions_with_same_participants(interaction_db_df) + + # 5. Warn about uniprot accession duplicates in protein_db_df + sanity_test_report_uniprot_accession_duplicates(protein_db_df) + + # 6. Warn the user if some complexes don't participate in any interactions + orphan_complexes = sanity_test_report_orphan_complexes(complex_db_df, interaction_db_df) + + # 7. Warn the user if some proteins don't participate in any interactions directly, + # or are part of some complex in orphan_complexes + sanity_test_report_orphan_proteins(protein_db_df, complex_db_df, interaction_db_df, protein_column_names, orphan_complexes) + + # 8. Warn the user if some interactions contain interactors that are neither + # in complex_input.csv or protein_input.csv + sanity_test_report_unknown_interactors(protein_db_df, complex_db_df, interaction_db_df) + + # 9. Warn if some complexes contain proteins not in protein_input.csv + sanity_test_report_unknown_proteins(protein_db_df, complex_db_df, protein_column_names) + + # 10. Warn if some proteins in protein_input.csv are not in gene_input.csv + sanity_test_report_proteins_not_in_genes_file(protein_db_df, gene_db_df) + + # 11. Warn if some receptor ids in tf_input_df are in neither gene_input.csv or complex_input.csv + sanity_test_report_tfs_not_in_gene_or_complex_files(gene_db_df, complex_db_df, tf_input_df) + if data_errors_found: raise DatabaseCreationException diff --git a/cellphonedb/utils/file_utils.py b/cellphonedb/utils/file_utils.py index 16e4792f..10db7c41 100755 --- a/cellphonedb/utils/file_utils.py +++ b/cellphonedb/utils/file_utils.py @@ -28,17 +28,8 @@ def read_data_table_from_file(file: str, index_column_first: bool = False, separ return _read_h5ad(file) if file_extension == '.h5': return _read_h5(file) - if file_extension == '.pickle': - try: - with open(file, 'rb') as f: - df = pickle.load(f) - if isinstance(df, pd.DataFrame): - return df - else: - raise NotADataFrameException(file) - except Exception: - raise ReadFromPickleException(file) + return _read_pickle(file) if not separator: separator = _get_separator(file_extension) @@ -111,6 +102,18 @@ def _read_h5(path: str) -> pd.DataFrame: return df +def _read_pickle(path: str) -> pd.DataFrame: + try: + with open(path, 'rb') as f: + df = pickle.load(f) + if isinstance(df, pd.DataFrame): + return df + else: + raise NotADataFrameException(path) + except Exception: + raise ReadFromPickleException(path) + + def _read_data(file_stream: TextIO, separator: str, index_column_first: bool, dtype=None, na_values=None, compression=None) -> pd.DataFrame: return pd.read_csv(file_stream, sep=separator, index_col=0 if index_column_first else None, dtype=dtype, diff --git a/cellphonedb/utils/search_utils.py b/cellphonedb/utils/search_utils.py index 25e0bdc2..d384ac35 100644 --- a/cellphonedb/utils/search_utils.py +++ b/cellphonedb/utils/search_utils.py @@ -29,6 +29,44 @@ def populate_proteins_for_complex(complex_name, complex_name2proteins, genes, co complex_name2proteins[complex_name] = constituent_proteins +def assemble_multidata_ids_for_search( + query_str: str, + genes: pd.DataFrame, + complex_expanded: pd.DataFrame, + complex_composition: pd.DataFrame, + gene_synonym2gene_name: dict) -> list: + multidata_ids = [] + for token in re.split(',\\s*| ', query_str): + + if token in gene_synonym2gene_name: + # Map any gene synonyms not in gene_input to gene names in gene_input + token = gene_synonym2gene_name[token] + + complex_multidata_ids = [] + # Attempt to find token in genes (N.B. genes contains protein information also) + gene_protein_data_list = \ + genes['protein_multidata_id'][ + genes[['ensembl', 'gene_name', 'name', 'protein_name']] + .apply(lambda row: row.astype(str).eq(token).any(), axis=1) + ].to_list() + if (len(gene_protein_data_list) > 0): + multidata_ids += gene_protein_data_list + for protein_multidata_id in gene_protein_data_list: + complex_multidata_ids = \ + complex_composition['complex_multidata_id'][complex_composition['protein_multidata_id'] + == protein_multidata_id].to_list() + multidata_ids += complex_multidata_ids + else: + # No match in genes - attempt to find token in complex_expanded + complex_multidata_ids += \ + complex_expanded['complex_multidata_id'][ + complex_expanded[['name']] + .apply(lambda row: row.astype(str).eq(token).any(), axis=1) + ].to_list() + multidata_ids += complex_multidata_ids + return multidata_ids + + def search(query_str: str = "", cpdb_file_path: str = None) -> (list, map, map, map, map): """ @@ -64,35 +102,9 @@ def search(query_str: str = "", complex_name2proteins = {} # Assemble a list of multidata_ids to search interactions DF with - multidata_ids = [] - for token in re.split(',\\s*| ', query_str): - - if token in gene_synonym2gene_name: - # Map any gene synonyms not in gene_input to gene names in gene_input - token = gene_synonym2gene_name[token] + multidata_ids = assemble_multidata_ids_for_search( + query_str, genes, complex_expanded, complex_composition, gene_synonym2gene_name) - complex_multidata_ids = [] - # Attempt to find token in genes (N.B. genes contains protein information also) - gene_protein_data_list = \ - genes['protein_multidata_id'][ - genes[['ensembl', 'gene_name', 'name', 'protein_name']] - .apply(lambda row: row.astype(str).eq(token).any(), axis=1) - ].to_list() - if (len(gene_protein_data_list) > 0): - multidata_ids += gene_protein_data_list - for protein_multidata_id in gene_protein_data_list: - complex_multidata_ids = \ - complex_composition['complex_multidata_id'][complex_composition['protein_multidata_id'] - == protein_multidata_id].to_list() - multidata_ids += complex_multidata_ids - else: - # No match in genes - attempt to find token in complex_expanded - complex_multidata_ids += \ - complex_expanded['complex_multidata_id'][ - complex_expanded[['name']] - .apply(lambda row: row.astype(str).eq(token).any(), axis=1) - ].to_list() - multidata_ids += complex_multidata_ids # Now search for all multidata_ids in interactions duration = time.time() - start dbg("Output for query '{}':".format(query_str)) @@ -314,6 +326,28 @@ def return_all_identifiers(genes: pd.DataFrame, interactions: pd.DataFrame) -> p return result +def collect_celltype_pairs( + significant_means: pd.DataFrame, + query_cell_types_1: list, + query_cell_types_2: list, + separator: str) -> list: + if query_cell_types_1 is None or query_cell_types_2 is None: + cols_filter = significant_means.filter(regex="\\{}".format(separator)).columns + all_cts = set([]) + for ct_pair in [i.split(separator) for i in cols_filter.tolist()]: + all_cts |= set(ct_pair) + all_cell_types = list(all_cts) + if query_cell_types_1 is None: + query_cell_types_1 = all_cell_types + if query_cell_types_2 is None: + query_cell_types_2 = all_cell_types + cell_type_pairs = [] + for ct in query_cell_types_1: + for ct1 in query_cell_types_2: + cell_type_pairs += ["{}{}{}".format(ct, separator, ct1), "{}{}{}".format(ct1, separator, ct)] + return cell_type_pairs + + def search_analysis_results( query_cell_types_1: list = None, query_cell_types_2: list = None, @@ -370,20 +404,7 @@ def search_analysis_results( return # Collect all combinations of cell types (disregarding the order) from query_cell_types_1 and query_cell_types_2 - if query_cell_types_1 is None or query_cell_types_2 is None: - cols_filter = significant_means.filter(regex="\\{}".format(separator)).columns - all_cts = set([]) - for ct_pair in [i.split(separator) for i in cols_filter.tolist()]: - all_cts |= set(ct_pair) - all_cell_types = list(all_cts) - if query_cell_types_1 is None: - query_cell_types_1 = all_cell_types - if query_cell_types_2 is None: - query_cell_types_2 = all_cell_types - cell_type_pairs = [] - for ct in query_cell_types_1: - for ct1 in query_cell_types_2: - cell_type_pairs += ["{}{}{}".format(ct, separator, ct1), "{}{}{}".format(ct1, separator, ct)] + cell_type_pairs = collect_celltype_pairs(significant_means, query_cell_types_1, query_cell_types_2, separator) cols_filter = significant_means.columns[significant_means.columns.isin(cell_type_pairs)] # Collect all interactions from query_genes and query_interactions