Skip to content

Commit

Permalink
Convert conn.drop() to net.connection_probability
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley authored and jasmainak committed May 19, 2021
1 parent 3dd9a50 commit 74a9704
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 58 deletions.
108 changes: 54 additions & 54 deletions hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,12 +1075,64 @@ def add_connection(self, src_gids, target_gids, loc, receptor,
_validate_type(item, (int, float), arg_name, 'int or float')
conn['nc_dict'][key] = item

# Probabilistically define connections
if probability != 1.0:
conn.drop(probability)
conn['probability'] = probability
self.connection_probability(conn, probability)

self.connectivity.append(deepcopy(conn))

def connection_probability(self, conn, probability):
"""Remove/keep a random subset of connections.
Parameters
----------
probability : float
Probability of connection between any src-target pair.
Defaults to 1.0 producing an all-to-all pattern.
Notes
-----
num_srcs and num_targets are not updated after pruning connections.
These variables are meant to describe the set of original connections
before they are randomly removed.
The probability attribute will store the most recent value passed to
this function. As such, this number does not accurately describe the
connections probability of the original set after successive calls.
"""
_validate_type(probability, float, 'probability')
if probability <= 0.0 or probability >= 1.0:
raise ValueError('probability must be in the range (0,1)')
# Flatten connections into a list of targets.
all_connections = np.concatenate(
[target_src_pair for
target_src_pair in conn['gid_pairs'].values()])
n_connections = np.round(
len(all_connections) * probability).astype(int)

# Select a random subset of connections to retain.
new_connections = np.random.choice(
range(len(all_connections)), n_connections, replace=False)
remove_srcs = list()
connection_idx = 0
for src_gid, target_src_pair in conn['gid_pairs'].items():
target_new = list()
for target_gid in target_src_pair:
if connection_idx in new_connections:
target_new.append(target_gid)
connection_idx += 1

# Update targets for src_gid
if target_new:
conn['gid_pairs'][src_gid] = target_new
else:
remove_srcs.append(src_gid)
# Remove src_gids with no targets
for src_gid in remove_srcs:
conn['gid_pairs'].pop(src_gid)

conn['probability'] = probability

def clear_connectivity(self):
"""Remove all connections defined in Network.connectivity_list"""
self.connectivity = list()
Expand Down Expand Up @@ -1189,58 +1241,6 @@ def __repr__(self):

return entr

def drop(self, probability):
"""Remove/keep a random subset of connections.
Parameters
----------
probability : float
Probability of connection between any src-target pair.
Defaults to 1.0 producing an all-to-all pattern.
Notes
-----
num_srcs and num_targets are not updated after pruning connections.
These variables are meant to describe the set of original connections
before they are randomly removed.
The probability attribute will store the most recent value passed to
this function. As such, this number does not accurately describe the
connections probability of the original set after successive calls.
"""
_validate_type(probability, float, 'probability')
if probability <= 0.0 or probability >= 1.0:
raise ValueError('probability must be in the range (0,1)')
# Flatten connections into a list of targets.
all_connections = np.concatenate(
[target_src_pair for
target_src_pair in self['gid_pairs'].values()])
n_connections = np.round(
len(all_connections) * probability).astype(int)

# Select a random subset of connections to retain.
new_connections = np.random.choice(
range(len(all_connections)), n_connections, replace=False)
remove_srcs = list()
connection_idx = 0
for src_gid, target_src_pair in self['gid_pairs'].items():
target_new = list()
for target_gid in target_src_pair:
if connection_idx in new_connections:
target_new.append(target_gid)
connection_idx += 1

# Update targets for src_gid
if target_new:
self['gid_pairs'][src_gid] = target_new
else:
remove_srcs.append(src_gid)
# Remove src_gids with no targets
for src_gid in remove_srcs:
self['gid_pairs'].pop(src_gid)

self['probability'] = probability

def plot(self, ax=None, show=True):
"""Plot connectivity matrix for instance of _Connectivity object.
Expand Down
12 changes: 8 additions & 4 deletions hnn_core/tests/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,15 +228,19 @@ def test_network():
kwargs[arg] = string_arg
net.add_connection(**kwargs)

# Check probability=0.5 produces half as many connections as default
net.add_connection(**kwargs_default)
kwargs = kwargs_default.copy()
kwargs['probability'] = 0.5
net.add_connection(**kwargs)
n_connections = np.sum(
[len(t_gids) for
t_gids in net.connectivity[0]['gid_pairs'].values()])
net.connectivity[0].drop(0.5)
t_gids in net.connectivity[-2]['gid_pairs'].values()])
n_connections_new = np.sum(
[len(t_gids) for
t_gids in net.connectivity[0]['gid_pairs'].values()])
t_gids in net.connectivity[-1]['gid_pairs'].values()])
assert n_connections_new == np.round(n_connections * 0.5).astype(int)
assert net.connectivity[0]['probability'] == 0.5
assert net.connectivity[-1]['probability'] == 0.5
with pytest.raises(ValueError, match='probability must be'):
kwargs = kwargs_default.copy()
kwargs['probability'] = -1.0
Expand Down

0 comments on commit 74a9704

Please sign in to comment.