From 74a97045cab1043c10013f1d21c8b1803d794a5a Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Fri, 14 May 2021 15:15:48 -0400 Subject: [PATCH] Convert conn.drop() to net.connection_probability --- hnn_core/network.py | 108 ++++++++++++++++----------------- hnn_core/tests/test_network.py | 12 ++-- 2 files changed, 62 insertions(+), 58 deletions(-) diff --git a/hnn_core/network.py b/hnn_core/network.py index a9dd81836..a935a33b3 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -1075,12 +1075,64 @@ def add_connection(self, src_gids, target_gids, loc, receptor, _validate_type(item, (int, float), arg_name, 'int or float') conn['nc_dict'][key] = item + # Probabilistically define connections if probability != 1.0: - conn.drop(probability) - conn['probability'] = probability + self.connection_probability(conn, probability) self.connectivity.append(deepcopy(conn)) + def connection_probability(self, conn, probability): + """Remove/keep a random subset of connections. + + Parameters + ---------- + probability : float + Probability of connection between any src-target pair. + Defaults to 1.0 producing an all-to-all pattern. + + Notes + ----- + num_srcs and num_targets are not updated after pruning connections. + These variables are meant to describe the set of original connections + before they are randomly removed. + + The probability attribute will store the most recent value passed to + this function. As such, this number does not accurately describe the + connections probability of the original set after successive calls. + """ + _validate_type(probability, float, 'probability') + if probability <= 0.0 or probability >= 1.0: + raise ValueError('probability must be in the range (0,1)') + # Flatten connections into a list of targets. + all_connections = np.concatenate( + [target_src_pair for + target_src_pair in conn['gid_pairs'].values()]) + n_connections = np.round( + len(all_connections) * probability).astype(int) + + # Select a random subset of connections to retain. + new_connections = np.random.choice( + range(len(all_connections)), n_connections, replace=False) + remove_srcs = list() + connection_idx = 0 + for src_gid, target_src_pair in conn['gid_pairs'].items(): + target_new = list() + for target_gid in target_src_pair: + if connection_idx in new_connections: + target_new.append(target_gid) + connection_idx += 1 + + # Update targets for src_gid + if target_new: + conn['gid_pairs'][src_gid] = target_new + else: + remove_srcs.append(src_gid) + # Remove src_gids with no targets + for src_gid in remove_srcs: + conn['gid_pairs'].pop(src_gid) + + conn['probability'] = probability + def clear_connectivity(self): """Remove all connections defined in Network.connectivity_list""" self.connectivity = list() @@ -1189,58 +1241,6 @@ def __repr__(self): return entr - def drop(self, probability): - """Remove/keep a random subset of connections. - - Parameters - ---------- - probability : float - Probability of connection between any src-target pair. - Defaults to 1.0 producing an all-to-all pattern. - - Notes - ----- - num_srcs and num_targets are not updated after pruning connections. - These variables are meant to describe the set of original connections - before they are randomly removed. - - The probability attribute will store the most recent value passed to - this function. As such, this number does not accurately describe the - connections probability of the original set after successive calls. - """ - _validate_type(probability, float, 'probability') - if probability <= 0.0 or probability >= 1.0: - raise ValueError('probability must be in the range (0,1)') - # Flatten connections into a list of targets. - all_connections = np.concatenate( - [target_src_pair for - target_src_pair in self['gid_pairs'].values()]) - n_connections = np.round( - len(all_connections) * probability).astype(int) - - # Select a random subset of connections to retain. - new_connections = np.random.choice( - range(len(all_connections)), n_connections, replace=False) - remove_srcs = list() - connection_idx = 0 - for src_gid, target_src_pair in self['gid_pairs'].items(): - target_new = list() - for target_gid in target_src_pair: - if connection_idx in new_connections: - target_new.append(target_gid) - connection_idx += 1 - - # Update targets for src_gid - if target_new: - self['gid_pairs'][src_gid] = target_new - else: - remove_srcs.append(src_gid) - # Remove src_gids with no targets - for src_gid in remove_srcs: - self['gid_pairs'].pop(src_gid) - - self['probability'] = probability - def plot(self, ax=None, show=True): """Plot connectivity matrix for instance of _Connectivity object. diff --git a/hnn_core/tests/test_network.py b/hnn_core/tests/test_network.py index e682587ea..21a77f000 100644 --- a/hnn_core/tests/test_network.py +++ b/hnn_core/tests/test_network.py @@ -228,15 +228,19 @@ def test_network(): kwargs[arg] = string_arg net.add_connection(**kwargs) + # Check probability=0.5 produces half as many connections as default + net.add_connection(**kwargs_default) + kwargs = kwargs_default.copy() + kwargs['probability'] = 0.5 + net.add_connection(**kwargs) n_connections = np.sum( [len(t_gids) for - t_gids in net.connectivity[0]['gid_pairs'].values()]) - net.connectivity[0].drop(0.5) + t_gids in net.connectivity[-2]['gid_pairs'].values()]) n_connections_new = np.sum( [len(t_gids) for - t_gids in net.connectivity[0]['gid_pairs'].values()]) + t_gids in net.connectivity[-1]['gid_pairs'].values()]) assert n_connections_new == np.round(n_connections * 0.5).astype(int) - assert net.connectivity[0]['probability'] == 0.5 + assert net.connectivity[-1]['probability'] == 0.5 with pytest.raises(ValueError, match='probability must be'): kwargs = kwargs_default.copy() kwargs['probability'] = -1.0