diff --git a/src/pypath/network.py b/src/pypath/network.py index 31d31a291..a8d58925c 100644 --- a/src/pypath/network.py +++ b/src/pypath/network.py @@ -43,6 +43,7 @@ import pypath.dataio as dataio import pypath.curl as curl import pypath.refs as refs_mod +import pypath.reflists as reflists import pypath.resources.network as network_resources @@ -365,6 +366,9 @@ class Network(session_mod.Logger): :arg bool make_df: Create a ``pandas.DataFrame`` already when creating the instance. If no network data loaded no data frame will be created. + :arg int ncbi_tax_id: + Restrict the network only to this organism. If ``None`` identifiers + from any organism will be allowed. """ _partners_methods = ( @@ -417,6 +421,7 @@ def __init__( df_columns = None, df_dtype = None, pickle_file = None, + ncbi_tax_id = 9606, **kwargs ): @@ -428,6 +433,7 @@ def __init__( self.df_with_references = df_with_references self.df_columns = df_columns self.df_dtype = df_dtype + self.ncbi_tax_id = ncbi_tax_id self.cache_dir = cache_mod.get_cachedir() self.keep_original_names = settings.get('network_keep_original_names') @@ -648,6 +654,8 @@ def load_resource( ) self._add_edge_list(only_directions = only_directions) + self.organisms_check() + self._log( 'Completed: loading network data from ' 'resource `%s`.' % resource.name @@ -1851,6 +1859,84 @@ def _add_update_edge( ) + def organisms_check( + self, + organisms = None, + remove_mismatches = True, + remove_nonspecific = False, + ): + """ + Scans the network for one or more organisms and removes the nodes + and interactions which belong to any other organism. + + :arg int,set,NoneType organisms: + One or more NCBI Taxonomy IDs. If ``None`` the value in + :py:attr:`ncbi_tax_id` will be used. If that's too is ``None`` + then only the entities with discrepancy between their stated + organism and their identifier. + :arg bool remove_mismatches: + Remove the entities where their ``identifier`` can not be found + in the reference list from the database for their ``taxon``. + :arg bool remove_nonspecific: + Remove the entities with taxonomy ID zero, which is used to + represent the non taxon specific entities such as metabolites + or drug compounds. + """ + + self._log( + 'Checking organisms. %u nodes and %u interactions before.' % ( + self.vcount, + self.ecount, + ) + ) + + organisms = common.to_set(organisms or self.ncbi_tax_id) + + to_remove = set() + + for node in self.nodes.values(): + + if organisms and node.taxon not in organisms: + + to_remove.add(node) + + if ( + ( + remove_mismatches and + not reflists.check( + name = node.identifier, + id_type = node.id_type, + ncbi_tax_id = node.taxon, + ) + ) or ( + remove_nonspecific and + not node.taxon + ) + ): + + to_remove.add(node) + + for node in to_remove: + + self.remove_node(node) + + self._log( + 'Finished checking organisms. ' + '%u nodes and %u interactions remained.' % ( + self.vcount, + self.ecount, + ) + ) + + + def get_organisms(self): + """ + Returns the set of all NCBI Taxonomy IDs occurring in the network. + """ + + return {n.taxon for n in self.nodes.values()} + + @property def vcount(self): @@ -2064,6 +2150,14 @@ def add_node(self, entity, attrs = None, add = True): def remove_node(self, entity): + """ + Removes a node with all its interactions. + If the removal of the interactions leaves any of the partner nodes + without interactions it will be removed too. + + :arg str,Entity entity: + A molecular entity identifier, label or ``Entity`` object. + """ entity = self.entity(entity) @@ -2076,14 +2170,23 @@ def remove_node(self, entity): if entity in self.interactions_by_nodes: - for i_key in self.interactions_by_nodes[entity]: + partners = set() + + for i_key in self.interactions_by_nodes[entity].copy(): - _ = self.interactions.pop(i_key, None) + self.remove_interaction(*i_key) - del self.interactions_by_nodes[entity] + _ = self.interactions_by_nodes.pop(entity, None) def remove_interaction(self, entity_a, entity_b): + """ + Removes the interaction between two nodes if exists. + + :arg str,Entity entity_a,entity_b: + A pair of molecular entity identifiers, labels or ``Entity`` + objects. + """ entity_a = self.entity(entity_a) entity_b = self.entity(entity_b) @@ -2098,11 +2201,17 @@ def remove_interaction(self, entity_a, entity_b): self.interactions_by_nodes[entity_a] -= keys self.interactions_by_nodes[entity_b] -= keys - if not self.interactions_by_nodes[entity_a]: + if ( + entity_a in self.interactions_by_nodes and + not self.interactions_by_nodes[entity_a] + ): self.remove_node(entity_a) - if not self.interactions_by_nodes[entity_b]: + if ( + entity_b in self.interactions_by_nodes and + not self.interactions_by_nodes[entity_b] + ): self.remove_node(entity_b) @@ -2616,12 +2725,13 @@ def omnipath( min_refs_undirected = 2, old_omnipath_resources = False, exclude = None, + ncbi_tax_id = 9606, **kwargs ): make_df = kwargs.pop('make_df', None) - new = cls(**kwargs) + new = cls(ncbi_tax_id = ncbi_tax_id, **kwargs) new.load_omnipath( omnipath = omnipath, @@ -2655,7 +2765,7 @@ def load_dorothea(self, levels = None, **kwargs): @classmethod - def dorothea(cls, levels = None, **kwargs): + def dorothea(cls, levels = None, ncbi_tax_id = 9606, **kwargs): """ Initializes a new ``Network`` object with loading the transcriptional regulation network from DoRothEA. @@ -2666,7 +2776,7 @@ def dorothea(cls, levels = None, **kwargs): make_df = kwargs.pop('make_df', False) - new = cls(**kwargs) + new = cls(ncbi_tax_id = ncbi_tax_id, **kwargs) new.load_dorothea(levels = levels, make_df = make_df) @@ -2724,6 +2834,7 @@ def transcription( reread = False, redownload = False, make_df = False, + ncbi_tax_id = 9606, **kwargs ): """ @@ -2735,6 +2846,8 @@ def transcription( load_args = locals() kwargs = load_args.pop('kwargs') + ncbi_tax_id = load_args.pop('ncbi_tax_id') + kwargs['ncbi_tax_id'] = ncbi_tax_id cls = load_args.pop('cls') new = cls(**kwargs) @@ -2761,6 +2874,7 @@ def mirna_target( reread = False, redownload = False, exclude = None, + ncbi_tax_id = 9606, **kwargs ): """ @@ -2770,7 +2884,7 @@ def mirna_target( **kwargs: passed to ``Network.__init__``. """ - new = cls(**kwargs) + new = cls(ncbi_tax_id = ncbi_tax_id, **kwargs) new.mirna_target( exclude = exclude,