diff --git a/cellphonedb/utils/db_utils.py b/cellphonedb/utils/db_utils.py index 06f48f2d..6a39724f 100644 --- a/cellphonedb/utils/db_utils.py +++ b/cellphonedb/utils/db_utils.py @@ -39,17 +39,18 @@ def get_protein_and_complex_data_for_web(cpdb_file_path) -> Tuple[dict, dict, di proteinTable = dbTableDFs['protein_table'] for col in set(PROTEIN_INFO_FIELDS_FOR_WEB + COMPLEX_INFO_FIELDS_FOR_WEB): - mtTable.loc[mtTable[col] == True, col] = col.capitalize() - mtTable.loc[mtTable[col] == False, col] = np.nan + mtTable = mtTable.astype({col: 'str'}) + mtTable.loc[mtTable[col] == "True", col] = col.capitalize() + mtTable.loc[mtTable[col] == "False", col] = np.nan if col in ['other_desc']: # Sanitize values for displaying to the user mtTable[col] = mtTable[col].str.replace("_", " ").str.capitalize() - mtp = mtTable[mtTable['is_complex'] == False] + mtp = mtTable[~mtTable['is_complex']] aux = pd.merge(mtp, proteinTable, left_on='id_multidata', right_on='protein_multidata_id') proteinAcc2Name = dict(zip(aux['name'], aux['protein_name'])) - mtc = mtTable[mtTable['is_complex'] == True] + mtc = mtTable[mtTable['is_complex']] aux = dict(zip(mtp['name'], mtp[PROTEIN_INFO_FIELDS_FOR_WEB].values)) protein2Info = {k: [x for x in aux[k] if str(x) != 'nan'] for k in aux} @@ -377,7 +378,7 @@ def create_db(target_dir) -> None: for r in complex_db_df[protein_column_names + ['complex_multidata_id', 'total_protein']].values.tolist(): for acc in filter(lambda x: isinstance(x, str), r): protein_multidata_id = \ - multidata_db_df.loc[(multidata_db_df['is_complex'] == False) & + multidata_db_df.loc[(~multidata_db_df['is_complex']) & (multidata_db_df['name'] == acc), ['id_multidata']].iat[0, 0] complex_multidata_id = r[pos] total_protein = r[pos+1] @@ -472,7 +473,7 @@ def get_dfs(gene_input=None, protein_input=None, complex_input=None, interaction def sanity_test_uniprot_accessions_map_to_multiple_gene_names(gene_db_df: pd.DataFrame): gene_names_uniprot_df = gene_db_df[['gene_name', 'uniprot']].copy() gene_names_uniprot_df.drop_duplicates(inplace=True) - dups = gene_names_uniprot_df[gene_names_uniprot_df['uniprot'].duplicated() == True] + dups = gene_names_uniprot_df[gene_names_uniprot_df['uniprot'].duplicated()] if not dups.empty: # data_errors_found = True print("WARNING: The following UniProt ids map to multiple gene names (it is expected that " +