diff --git a/cellphonedb/utils/db_releases_utils.py b/cellphonedb/utils/db_releases_utils.py
index c43c72b..ae8d033 100644
--- a/cellphonedb/utils/db_releases_utils.py
+++ b/cellphonedb/utils/db_releases_utils.py
@@ -43,17 +43,7 @@ def get_remote_database_versions_html(include_file_browsing: bool = False, min_v
rel_version = float('.'.join(rel_tag.replace('v', '').split(".")[0:2]))
if rel_version < min_version:
continue
- html += "
{} | " \
- .format(css_style, rel['html_url'], rel['tag_name'])
- html += "{} | ".format(css_style, rel['published_at'].split("T")[0])
- if include_file_browsing:
- html += ("" +
- "pageview").format(css_style, rel['tag_name'])
- html += "".format(rel['tag_name'])
- for file_name in ["gene_input", "protein_input", "complex_input", "interaction_input"]:
- html += "- {}
" \
- .format(rel['tag_name'], file_name, file_name)
- html += " | "
+ html += get_release_info(rel, css_style, include_file_browsing)
html += ""
html += ""
result['db_releases_html_table'] = html
@@ -67,6 +57,25 @@ def get_remote_database_versions_html(include_file_browsing: bool = False, min_v
return result
+def get_release_info(
+ rel: dict,
+ css_style: str,
+ include_file_browsing: bool
+) -> str:
+ html = "{} | " \
+ .format(css_style, rel['html_url'], rel['tag_name'])
+ html += "{} | ".format(css_style, rel['published_at'].split("T")[0])
+ if include_file_browsing:
+ html += ("" +
+ "pageview").format(css_style, rel['tag_name'])
+ html += "".format(rel['tag_name'])
+ for file_name in ["gene_input", "protein_input", "complex_input", "interaction_input"]:
+ html += "- {}
" \
+ .format(rel['tag_name'], file_name, file_name)
+ html += " | "
+ return html
+
+
def _github_query(kind) -> Union[dict, list]:
queries = {
'releases': 'https://api.github.com/repos/{}/{}/releases'.format('ventolab', 'cellphonedb-data'),
diff --git a/cellphonedb/utils/db_utils.py b/cellphonedb/utils/db_utils.py
index 2a17d80..06f48f2 100644
--- a/cellphonedb/utils/db_utils.py
+++ b/cellphonedb/utils/db_utils.py
@@ -224,6 +224,59 @@ def get_column_names_for_db_version(complex_db_df, interactions_df, protein_df)
return (protein_column_names, interaction_column_names1, interaction_column_names2, version, complex_columns)
+def collect_protein_data(data_dfs: dict) -> "tuple[pd.DataFrame, list]":
+ # Collect protein data
+ protein_db_df = data_dfs['protein_input'][['protein_name', 'tags', 'tags_reason', 'tags_description', 'uniprot']]
+ num_proteins = protein_db_df.shape[0]
+ multidata_id_list_so_far = list(range(num_proteins))
+ protein_db_df.insert(0, 'id_protein', multidata_id_list_so_far, False)
+ protein_db_df.insert(len(protein_db_df.columns), 'protein_multidata_id', protein_db_df['id_protein'].tolist(), True)
+ # dbg(protein_db_df.info)
+ return protein_db_df, multidata_id_list_so_far
+
+
+def collect_gene_data(data_dfs: dict, protein_db_df: pd.DataFrame) -> pd.DataFrame:
+ gene_db_df = data_dfs['gene_input'][['ensembl', 'gene_name', 'hgnc_symbol', 'uniprot']]
+ num_genes = gene_db_df.shape[0]
+ gene_db_df.insert(0, 'id_gene', list(range(num_genes)), False)
+ # Assign values from protein_db_df['protein_multidata_id'] into gene_db_df['protein_id']
+ # via join between 'uniprot' and 'protein_name'
+ gene_db_df = pd.merge(gene_db_df, protein_db_df[['protein_name', 'protein_multidata_id', 'uniprot']], on='uniprot')
+ gene_db_df = gene_db_df.drop('uniprot', axis=1)
+    protein_db_df.drop('uniprot', axis=1, inplace=True)
+ gene_db_df.rename(columns={'protein_multidata_id': 'protein_id'}, inplace=True)
+ # print(gene_db_df.info)
+ return gene_db_df
+
+
+def collect_receptor_to_tf_mapping(data_dfs: dict) -> "pd.DataFrame | None":
+ receptor_to_tf_df = None
+ # Cater for DB version-dependent input files
+ if data_dfs['transcription_factor_input'] is not None:
+ receptor_to_tf_df = data_dfs['transcription_factor_input'][['receptor_id', 'TF_symbol']].copy()
+ receptor_to_tf_df.columns = ['Receptor', 'TF']
+ # Strip any leading or trailing spaces
+ receptor_to_tf_df.replace(r'\s*(.*?)\s*', r'\1', regex=True, inplace=True)
+ return receptor_to_tf_df
+
+
+def collect_gene_synonym_to_gene_name_mapping(data_dfs: dict, gene_db_df: pd.DataFrame) -> "pd.DataFrame | None":
+ gene_synonym_to_gene_name_db_df = None
+ # Cater for DB version-dependent input files
+ if data_dfs['gene_synonyms_input'] is not None:
+ gene_synonym_to_gene_name = {}
+ for gene_names in data_dfs['gene_synonyms_input']\
+ .filter(regex=("Gene Names.*")).dropna().agg(' '.join, axis=1).tolist():
+ gene_names_arr = re.split(';\\s*|\\s+', gene_names)
+ for gene_name in gene_db_df[gene_db_df['gene_name'].isin(gene_names_arr)]['gene_name'].tolist():
+ for gene_synonym in gene_names_arr:
+ if gene_synonym != gene_name:
+ gene_synonym_to_gene_name[gene_synonym] = gene_name
+ gene_synonym_to_gene_name_db_df = pd.DataFrame(gene_synonym_to_gene_name.items(),
+ columns=['Gene Synonym', 'Gene Name'])
+ return gene_synonym_to_gene_name_db_df
+
+
def create_db(target_dir) -> None:
"""
Creates CellphoneDB databases file (cellphonedb.zip) in directory.
@@ -250,64 +303,32 @@ def create_db(target_dir) -> None:
gene_synonyms_input = os.path.join(target_dir, "sources/uniprot_synonyms.tsv")
pathlib.Path(target_dir).mkdir(parents=True, exist_ok=True)
- dataDFs = getDFs(gene_input=gene_input, protein_input=protein_input, complex_input=complex_input,
- interaction_input=interaction_input, transcription_factor_input=transcription_factor_input,
- gene_synonyms_input=gene_synonyms_input)
+ data_dfs = get_dfs(gene_input=gene_input, protein_input=protein_input, complex_input=complex_input,
+ interaction_input=interaction_input, transcription_factor_input=transcription_factor_input,
+ gene_synonyms_input=gene_synonyms_input)
(protein_column_names, interaction_column_names1, interaction_column_names2, version, complex_columns) = \
get_column_names_for_db_version(
- dataDFs['complex_input'], dataDFs['interaction_input'], dataDFs['protein_input'])
+ data_dfs['complex_input'], data_dfs['interaction_input'], data_dfs['protein_input'])
# Perform sanity tests on *_input files and report any issues to the user as warnings
- run_sanity_tests(dataDFs, protein_column_names, version)
+ run_sanity_tests(data_dfs, protein_column_names, version)
# Collect protein data
- protein_db_df = dataDFs['protein_input'][['protein_name', 'tags', 'tags_reason', 'tags_description', 'uniprot']]
- num_proteins = protein_db_df.shape[0]
- multidata_id_list_so_far = list(range(num_proteins))
- protein_db_df.insert(0, 'id_protein', multidata_id_list_so_far, False)
- protein_db_df.insert(len(protein_db_df.columns), 'protein_multidata_id', protein_db_df['id_protein'].tolist(), True)
- # dbg(protein_db_df.info)
+ protein_db_df, multidata_id_list_so_far = collect_protein_data(data_dfs)
# Collect gene data
- gene_db_df = dataDFs['gene_input'][['ensembl', 'gene_name', 'hgnc_symbol', 'uniprot']]
- num_genes = gene_db_df.shape[0]
- gene_db_df.insert(0, 'id_gene', list(range(num_genes)), False)
- # Assign values from protein_db_df['protein_multidata_id'] into gene_db_df['protein_id']
- # via join between 'uniprot' and 'protein_name'
- gene_db_df = pd.merge(gene_db_df, protein_db_df[['protein_name', 'protein_multidata_id', 'uniprot']], on='uniprot')
- gene_db_df = gene_db_df.drop('uniprot', axis=1)
- protein_db_df = protein_db_df.drop('uniprot', axis=1)
- gene_db_df.rename(columns={'protein_multidata_id': 'protein_id'}, inplace=True)
- # print(gene_db_df.info)
+ gene_db_df = collect_gene_data(data_dfs, protein_db_df)
# Collect mapping: (receptor) gene name -> TF gene name (in transcription_factor_input.tsv)
- receptor_to_tf_df = None
- # Cater for DB version-dependent input files
- if dataDFs['transcription_factor_input'] is not None:
- receptor_to_tf_df = dataDFs['transcription_factor_input'][['receptor_id', 'TF_symbol']].copy()
- receptor_to_tf_df.columns = ['Receptor', 'TF']
- # Strip any leading or trailing spaces
- receptor_to_tf_df.replace(r'\s*(.*?)\s*', r'\1', regex=True, inplace=True)
+ receptor_to_tf_df = collect_receptor_to_tf_mapping(data_dfs)
# Collect mapping: gene synonym (not in gene_input.csv) -> gene name (in gene_input.csv)
- gene_synonym_to_gene_name_db_df = None
- # Cater for DB version-dependent input files
- if dataDFs['gene_synonyms_input'] is not None:
- gene_synonym_to_gene_name = {}
- for gene_names in dataDFs['gene_synonyms_input']\
- .filter(regex=("Gene Names.*")).dropna().agg(' '.join, axis=1).tolist():
- gene_names_arr = re.split(';\\s*|\\s+', gene_names)
- for gene_name in gene_db_df[gene_db_df['gene_name'].isin(gene_names_arr)]['gene_name'].tolist():
- for gene_synonym in gene_names_arr:
- if gene_synonym != gene_name:
- gene_synonym_to_gene_name[gene_synonym] = gene_name
- gene_synonym_to_gene_name_db_df = pd.DataFrame(gene_synonym_to_gene_name.items(),
- columns=['Gene Synonym', 'Gene Name'])
+ gene_synonym_to_gene_name_db_df = collect_gene_synonym_to_gene_name_mapping(data_dfs, gene_db_df)
# Collect complex data
cols = protein_column_names + ['pdb_structure', 'pdb_id', 'stoichiometry', 'comments_complex'] + complex_columns
- complex_db_df = dataDFs['complex_input'][cols]
+ complex_db_df = data_dfs['complex_input'][cols]
# Note that uniprot_* cols will be dropped after complex_composition_df has been constructed
num_complexes = complex_db_df.shape[0]
@@ -322,11 +343,11 @@ def create_db(target_dir) -> None:
# Collect multidata
# Insert proteins into multidata
multidata_db_df = \
- dataDFs['protein_input'][['uniprot', 'receptor', 'receptor_desc', 'other', 'other_desc', 'secreted_highlight',
+ data_dfs['protein_input'][['uniprot', 'receptor', 'receptor_desc', 'other', 'other_desc', 'secreted_highlight',
'secreted_desc', 'transmembrane', 'secreted', 'peripheral', 'integrin']].copy()
multidata_db_df.rename(columns={'uniprot': 'name'}, inplace=True)
multidata_ids = pd.merge(
- dataDFs['protein_input'][['protein_name']],
+ data_dfs['protein_input'][['protein_name']],
protein_db_df[['protein_name', 'protein_multidata_id']], on='protein_name')['protein_multidata_id'].tolist()
multidata_db_df.insert(0, 'id_multidata', multidata_ids, False)
multidata_db_df.insert(len(multidata_db_df.columns), 'is_complex',
@@ -335,7 +356,7 @@ def create_db(target_dir) -> None:
# Insert complexes into multidata
cols = ['complex_name', 'receptor', 'receptor_desc', 'other', 'other_desc', 'secreted_highlight', 'secreted_desc',
'transmembrane', 'secreted', 'peripheral', 'integrin']
- complex_aux_df = dataDFs['complex_input'][cols].copy()
+ complex_aux_df = data_dfs['complex_input'][cols].copy()
complex_aux_df.rename(columns={'complex_name': 'name'}, inplace=True)
complex_aux_df.insert(0, 'id_multidata', complex_multidata_ids, False)
complex_aux_df.insert(len(complex_aux_df.columns), 'is_complex',
@@ -371,7 +392,7 @@ def create_db(target_dir) -> None:
complex_db_df = complex_db_df.drop(col, axis=1)
# Collect interaction data
- interactions_aux_df = pd.merge(dataDFs['interaction_input'], multidata_db_df,
+ interactions_aux_df = pd.merge(data_dfs['interaction_input'], multidata_db_df,
left_on=['partner_a'], right_on=['name'])
interactions_aux_df = pd.merge(interactions_aux_df, multidata_db_df,
left_on=['partner_b'], right_on=['name'], suffixes=['_x', '_y'])
@@ -436,8 +457,8 @@ def download_released_files(target_dir, cpdb_version, regex):
print("Downloaded {} into {}".format(fname, target_dir))
-def getDFs(gene_input=None, protein_input=None, complex_input=None, interaction_input=None,
- transcription_factor_input=None, gene_synonyms_input=None):
+def get_dfs(gene_input=None, protein_input=None, complex_input=None, interaction_input=None,
+ transcription_factor_input=None, gene_synonyms_input=None):
dfs = {}
dfs['gene_input'] = file_utils.read_data_table_from_file(gene_input)
dfs['protein_input'] = file_utils.read_data_table_from_file(protein_input)
@@ -448,18 +469,7 @@ def getDFs(gene_input=None, protein_input=None, complex_input=None, interaction_
return dfs
-def run_sanity_tests(dataDFs, protein_column_names, version):
- data_errors_found = False
- protein_db_df = dataDFs['protein_input']
- complex_db_df = dataDFs['complex_input']
- gene_db_df = dataDFs['gene_input']
- interaction_db_df = dataDFs['interaction_input']
- tf_input_df = None
- # Cater for DB version-dependent input files
- if 'transcription_factor_input' in dataDFs:
- tf_input_df = dataDFs['transcription_factor_input']
-
- # 1. Report any uniprot accessions that map to multiple gene_names
+def sanity_test_uniprot_accessions_map_to_multiple_gene_names(gene_db_df: pd.DataFrame):
gene_names_uniprot_df = gene_db_df[['gene_name', 'uniprot']].copy()
gene_names_uniprot_df.drop_duplicates(inplace=True)
dups = gene_names_uniprot_df[gene_names_uniprot_df['uniprot'].duplicated() == True]
@@ -469,17 +479,16 @@ def run_sanity_tests(dataDFs, protein_column_names, version):
"they should map to only one):")
print(", ".join(dups['uniprot'].tolist()))
- # 2. Warn about complex name duplicates in complex_db_df
+
+def sanity_test_report_complex_duplicates(complex_db_df: pd.DataFrame):
test_complex_db_df = complex_db_df.set_index('complex_name')
if test_complex_db_df.index.has_duplicates:
print("WARNING: complex_input.csv has the following duplicates:")
print("\n".join(complex_db_df[test_complex_db_df.index.duplicated(keep='first')]['complex_name'].tolist()) + "\n")
- # 3. Report complexes with (possibly) different names, but with the same uniprot
- # accession participants (though not necessarily in the same order - hence the use of set below)
- # NB. Use set below as we don't care about the order of participants when looking for duplicates
- # NB. Report duplicate complexes _only if_ at least one duplicate's complex_db_df['version']
- # does not start with CORE_CELLPHONEDB_DATA)
+
+def sanity_test_report_complexes_with_same_participants(
+ complex_db_df: pd.DataFrame, version: str, protein_column_names: list):
participants_set_to_complex_names = {}
participants_set_to_data_sources = {}
cols = ['complex_name'] + version + protein_column_names
@@ -511,8 +520,8 @@ def run_sanity_tests(dataDFs, protein_column_names, version):
print("WARNING: The following multiple complexes (left) appear to have the same composition (right):")
print(complex_dups)
- # 4. Report interactions with (possibly) a different name, but with the same participants
- # (though not necessarily in the same order - hence the use of set below)
+
+def sanity_test_report_interactions_with_same_participants(interaction_db_df: pd.DataFrame):
partner_sets = [set([i for i in row]) for row in
interaction_db_df[['partner_a', 'partner_b']].itertuples(index=False)]
# Find duplicate sets of partners
@@ -525,13 +534,15 @@ def run_sanity_tests(dataDFs, protein_column_names, version):
print(','.join(dup))
print()
- # 5. Warn about uniprot accession duplicates in protein_db_df
+
+def sanity_test_report_uniprot_accession_duplicates(protein_db_df: pd.DataFrame):
test_protein_db_df = protein_db_df.set_index('uniprot')
if test_protein_db_df.index.has_duplicates:
print("WARNING: protein_input.csv has the following UniProt accession duplicates:")
print("\n".join(protein_db_df[test_protein_db_df.index.duplicated(keep='first')]['uniprot'].tolist()) + "\n")
- # 6. Warn the user if some complexes don't participate in any interactions
+
+def sanity_test_report_orphan_complexes(complex_db_df: pd.DataFrame, interaction_db_df: pd.DataFrame) -> set:
all_complexes_set = set(complex_db_df['complex_name'].tolist())
interaction_participants_set = set(interaction_db_df['partner_a'].tolist() + interaction_db_df['partner_b'].tolist())
orphan_complexes = all_complexes_set - interaction_participants_set
@@ -539,10 +550,18 @@ def run_sanity_tests(dataDFs, protein_column_names, version):
print("WARNING: The following complexes are not found in interaction_input.txt:")
print("\n".join(orphan_complexes))
print()
+ return orphan_complexes
- # 7. Warn the user if some proteins don't participate in any interactions directly,
- # or are part some complex in orphan_complexes
+
+def sanity_test_report_orphan_proteins(
+ protein_db_df: pd.DataFrame,
+ complex_db_df: pd.DataFrame,
+ interaction_db_df: pd.DataFrame,
+ protein_column_names: list,
+ orphan_complexes: set):
all_proteins_set = set(protein_db_df['uniprot'].tolist())
+ interaction_participants_set = set(
+ interaction_db_df['partner_a'].tolist() + interaction_db_df['partner_b'].tolist())
proteins_in_complexes_participating_in_interactions = []
for colName in protein_column_names:
proteins_in_complexes_participating_in_interactions += \
@@ -554,21 +573,30 @@ def run_sanity_tests(dataDFs, protein_column_names, version):
"or via complexes they are part of):")
print("\n".join(orphan_proteins))
- # 8. Warn the user if some interactions contain interactors that are neither
- # in complex_input.csv or protein_input.csv
+
+def sanity_test_report_unknown_interactors(
+ protein_db_df: pd.DataFrame,
+ complex_db_df: pd.DataFrame,
+ interaction_db_df: pd.DataFrame
+):
unknown_interactors = set()
for col in ['partner_a', 'partner_b']:
aux_df = pd.merge(interaction_db_df, protein_db_df, left_on=col, right_on='uniprot', how='outer')
unknown_interactor_proteins = set(aux_df[pd.isnull(aux_df['uniprot'])][col].tolist())
aux_df = pd.merge(interaction_db_df, complex_db_df, left_on=col, right_on='complex_name', how='outer')
unknown_interactor_complexes = set(aux_df[pd.isnull(aux_df['complex_name'])][col].tolist())
- unknown_interactors = unknown_interactors.union(unknown_interactor_proteins.intersection(unknown_interactor_complexes))
+ unknown_interactors = unknown_interactors.union(
+ unknown_interactor_proteins.intersection(unknown_interactor_complexes))
if unknown_interactors:
print("WARNING: The following interactors in interaction_input.txt could not be found in either " +
"protein_input.csv or complex_indput.csv:")
print("\n".join(sorted(unknown_interactors)) + "\n")
- # 9. Warn if some complexes contain proteins not in protein_input.csv
+
+def sanity_test_report_unknown_proteins(
+ protein_db_df: pd.DataFrame,
+ complex_db_df: pd.DataFrame,
+ protein_column_names: list):
unknown_proteins = set()
for col in protein_column_names:
aux_df = pd.merge(complex_db_df, protein_db_df, left_on=col, right_on='uniprot', how='outer')
@@ -578,7 +606,10 @@ def run_sanity_tests(dataDFs, protein_column_names, version):
print("WARNING: The following proteins in complex_input.txt could not be found in protein_input.csv:")
print("\n".join(sorted(unknown_proteins)) + "\n")
- # 10. Warn if some proteins in protein_input.csv are not in gene_input.csv
+
+def sanity_test_report_proteins_not_in_genes_file(
+ protein_db_df: pd.DataFrame,
+ gene_db_df: pd.DataFrame):
proteins = set(protein_db_df['uniprot'].tolist())
unknown_proteins = proteins.difference(set(gene_db_df['uniprot'].tolist()))
if unknown_proteins:
@@ -586,7 +617,11 @@ def run_sanity_tests(dataDFs, protein_column_names, version):
print("\n".join(sorted(unknown_proteins)) + "\n")
print()
- # 11. Warn if some receptor ids in tf_input_df are in neither gene_input.csv or complex_input.csv
+
+def sanity_test_report_tfs_not_in_gene_or_complex_files(
+ gene_db_df: pd.DataFrame,
+ complex_db_df: pd.DataFrame,
+ tf_input_df: pd.DataFrame):
if tf_input_df is not None:
# Cater for DB version-dependent input files
for (bioentity, df) in {"gene": gene_db_df, "complex": complex_db_df}.items():
@@ -604,5 +639,57 @@ def run_sanity_tests(dataDFs, protein_column_names, version):
print("\n".join(set(bioentities_not_in_input)))
print()
+
+def run_sanity_tests(data_dfs, protein_column_names, version):
+ data_errors_found = False
+ protein_db_df = data_dfs['protein_input']
+ complex_db_df = data_dfs['complex_input']
+ gene_db_df = data_dfs['gene_input']
+ interaction_db_df = data_dfs['interaction_input']
+ tf_input_df = None
+ # Cater for DB version-dependent input files
+ if 'transcription_factor_input' in data_dfs:
+ tf_input_df = data_dfs['transcription_factor_input']
+
+ # 1. Report any uniprot accessions that map to multiple gene_names
+ sanity_test_uniprot_accessions_map_to_multiple_gene_names(gene_db_df)
+
+ # 2. Warn about complex name duplicates in complex_db_df
+ sanity_test_report_complex_duplicates(complex_db_df)
+
+ # 3. Report complexes with (possibly) different names, but with the same uniprot
+ # accession participants (though not necessarily in the same order - hence the use of set below)
+ # NB. Use set below as we don't care about the order of participants when looking for duplicates
+ # NB. Report duplicate complexes _only if_ at least one duplicate's complex_db_df['version']
+ # does not start with CORE_CELLPHONEDB_DATA)
+ sanity_test_report_complexes_with_same_participants(complex_db_df, version, protein_column_names)
+
+ # 4. Report interactions with (possibly) a different name, but with the same participants
+ # (though not necessarily in the same order - hence the use of set below)
+ sanity_test_report_interactions_with_same_participants(interaction_db_df)
+
+ # 5. Warn about uniprot accession duplicates in protein_db_df
+ sanity_test_report_uniprot_accession_duplicates(protein_db_df)
+
+ # 6. Warn the user if some complexes don't participate in any interactions
+ orphan_complexes = sanity_test_report_orphan_complexes(complex_db_df, interaction_db_df)
+
+ # 7. Warn the user if some proteins don't participate in any interactions directly,
+ # or are part of some complex in orphan_complexes
+ sanity_test_report_orphan_proteins(protein_db_df, complex_db_df, interaction_db_df, protein_column_names, orphan_complexes)
+
+ # 8. Warn the user if some interactions contain interactors that are neither
+ # in complex_input.csv or protein_input.csv
+ sanity_test_report_unknown_interactors(protein_db_df, complex_db_df, interaction_db_df)
+
+ # 9. Warn if some complexes contain proteins not in protein_input.csv
+ sanity_test_report_unknown_proteins(protein_db_df, complex_db_df, protein_column_names)
+
+ # 10. Warn if some proteins in protein_input.csv are not in gene_input.csv
+ sanity_test_report_proteins_not_in_genes_file(protein_db_df, gene_db_df)
+
+ # 11. Warn if some receptor ids in tf_input_df are in neither gene_input.csv or complex_input.csv
+ sanity_test_report_tfs_not_in_gene_or_complex_files(gene_db_df, complex_db_df, tf_input_df)
+
if data_errors_found:
raise DatabaseCreationException
diff --git a/cellphonedb/utils/file_utils.py b/cellphonedb/utils/file_utils.py
index 16e4792..10db7c4 100755
--- a/cellphonedb/utils/file_utils.py
+++ b/cellphonedb/utils/file_utils.py
@@ -28,17 +28,8 @@ def read_data_table_from_file(file: str, index_column_first: bool = False, separ
return _read_h5ad(file)
if file_extension == '.h5':
return _read_h5(file)
-
if file_extension == '.pickle':
- try:
- with open(file, 'rb') as f:
- df = pickle.load(f)
- if isinstance(df, pd.DataFrame):
- return df
- else:
- raise NotADataFrameException(file)
- except Exception:
- raise ReadFromPickleException(file)
+ return _read_pickle(file)
if not separator:
separator = _get_separator(file_extension)
@@ -111,6 +102,18 @@ def _read_h5(path: str) -> pd.DataFrame:
return df
+def _read_pickle(path: str) -> pd.DataFrame:
+ try:
+ with open(path, 'rb') as f:
+ df = pickle.load(f)
+ if isinstance(df, pd.DataFrame):
+ return df
+ else:
+ raise NotADataFrameException(path)
+ except Exception:
+ raise ReadFromPickleException(path)
+
+
def _read_data(file_stream: TextIO, separator: str, index_column_first: bool, dtype=None,
na_values=None, compression=None) -> pd.DataFrame:
return pd.read_csv(file_stream, sep=separator, index_col=0 if index_column_first else None, dtype=dtype,
diff --git a/cellphonedb/utils/search_utils.py b/cellphonedb/utils/search_utils.py
index 25e0bdc..d384ac3 100644
--- a/cellphonedb/utils/search_utils.py
+++ b/cellphonedb/utils/search_utils.py
@@ -29,6 +29,44 @@ def populate_proteins_for_complex(complex_name, complex_name2proteins, genes, co
complex_name2proteins[complex_name] = constituent_proteins
+def assemble_multidata_ids_for_search(
+ query_str: str,
+ genes: pd.DataFrame,
+ complex_expanded: pd.DataFrame,
+ complex_composition: pd.DataFrame,
+ gene_synonym2gene_name: dict) -> list:
+ multidata_ids = []
+ for token in re.split(',\\s*| ', query_str):
+
+ if token in gene_synonym2gene_name:
+ # Map any gene synonyms not in gene_input to gene names in gene_input
+ token = gene_synonym2gene_name[token]
+
+ complex_multidata_ids = []
+ # Attempt to find token in genes (N.B. genes contains protein information also)
+ gene_protein_data_list = \
+ genes['protein_multidata_id'][
+ genes[['ensembl', 'gene_name', 'name', 'protein_name']]
+ .apply(lambda row: row.astype(str).eq(token).any(), axis=1)
+ ].to_list()
+ if (len(gene_protein_data_list) > 0):
+ multidata_ids += gene_protein_data_list
+ for protein_multidata_id in gene_protein_data_list:
+ complex_multidata_ids = \
+ complex_composition['complex_multidata_id'][complex_composition['protein_multidata_id']
+ == protein_multidata_id].to_list()
+ multidata_ids += complex_multidata_ids
+ else:
+ # No match in genes - attempt to find token in complex_expanded
+ complex_multidata_ids += \
+ complex_expanded['complex_multidata_id'][
+ complex_expanded[['name']]
+ .apply(lambda row: row.astype(str).eq(token).any(), axis=1)
+ ].to_list()
+ multidata_ids += complex_multidata_ids
+ return multidata_ids
+
+
def search(query_str: str = "",
cpdb_file_path: str = None) -> (list, map, map, map, map):
"""
@@ -64,35 +102,9 @@ def search(query_str: str = "",
complex_name2proteins = {}
# Assemble a list of multidata_ids to search interactions DF with
- multidata_ids = []
- for token in re.split(',\\s*| ', query_str):
-
- if token in gene_synonym2gene_name:
- # Map any gene synonyms not in gene_input to gene names in gene_input
- token = gene_synonym2gene_name[token]
+ multidata_ids = assemble_multidata_ids_for_search(
+ query_str, genes, complex_expanded, complex_composition, gene_synonym2gene_name)
- complex_multidata_ids = []
- # Attempt to find token in genes (N.B. genes contains protein information also)
- gene_protein_data_list = \
- genes['protein_multidata_id'][
- genes[['ensembl', 'gene_name', 'name', 'protein_name']]
- .apply(lambda row: row.astype(str).eq(token).any(), axis=1)
- ].to_list()
- if (len(gene_protein_data_list) > 0):
- multidata_ids += gene_protein_data_list
- for protein_multidata_id in gene_protein_data_list:
- complex_multidata_ids = \
- complex_composition['complex_multidata_id'][complex_composition['protein_multidata_id']
- == protein_multidata_id].to_list()
- multidata_ids += complex_multidata_ids
- else:
- # No match in genes - attempt to find token in complex_expanded
- complex_multidata_ids += \
- complex_expanded['complex_multidata_id'][
- complex_expanded[['name']]
- .apply(lambda row: row.astype(str).eq(token).any(), axis=1)
- ].to_list()
- multidata_ids += complex_multidata_ids
# Now search for all multidata_ids in interactions
duration = time.time() - start
dbg("Output for query '{}':".format(query_str))
@@ -314,6 +326,28 @@ def return_all_identifiers(genes: pd.DataFrame, interactions: pd.DataFrame) -> p
return result
+def collect_celltype_pairs(
+ significant_means: pd.DataFrame,
+ query_cell_types_1: list,
+ query_cell_types_2: list,
+ separator: str) -> list:
+ if query_cell_types_1 is None or query_cell_types_2 is None:
+ cols_filter = significant_means.filter(regex="\\{}".format(separator)).columns
+ all_cts = set([])
+ for ct_pair in [i.split(separator) for i in cols_filter.tolist()]:
+ all_cts |= set(ct_pair)
+ all_cell_types = list(all_cts)
+ if query_cell_types_1 is None:
+ query_cell_types_1 = all_cell_types
+ if query_cell_types_2 is None:
+ query_cell_types_2 = all_cell_types
+ cell_type_pairs = []
+ for ct in query_cell_types_1:
+ for ct1 in query_cell_types_2:
+ cell_type_pairs += ["{}{}{}".format(ct, separator, ct1), "{}{}{}".format(ct1, separator, ct)]
+ return cell_type_pairs
+
+
def search_analysis_results(
query_cell_types_1: list = None,
query_cell_types_2: list = None,
@@ -370,20 +404,7 @@ def search_analysis_results(
return
# Collect all combinations of cell types (disregarding the order) from query_cell_types_1 and query_cell_types_2
- if query_cell_types_1 is None or query_cell_types_2 is None:
- cols_filter = significant_means.filter(regex="\\{}".format(separator)).columns
- all_cts = set([])
- for ct_pair in [i.split(separator) for i in cols_filter.tolist()]:
- all_cts |= set(ct_pair)
- all_cell_types = list(all_cts)
- if query_cell_types_1 is None:
- query_cell_types_1 = all_cell_types
- if query_cell_types_2 is None:
- query_cell_types_2 = all_cell_types
- cell_type_pairs = []
- for ct in query_cell_types_1:
- for ct1 in query_cell_types_2:
- cell_type_pairs += ["{}{}{}".format(ct, separator, ct1), "{}{}{}".format(ct1, separator, ct)]
+ cell_type_pairs = collect_celltype_pairs(significant_means, query_cell_types_1, query_cell_types_2, separator)
cols_filter = significant_means.columns[significant_means.columns.isin(cell_type_pairs)]
# Collect all interactions from query_genes and query_interactions