Skip to content

Commit

Permalink
Add seeding for probabilistic connections
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley authored and jasmainak committed May 19, 2021
1 parent ba84928 commit 07b3990
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions hnn_core/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,7 +958,7 @@ def _all_to_all_connect(self, src_cell, target_cell,
nc_dict['A_weight'], nc_dict['A_delay'], nc_dict['lamtha'])

def add_connection(self, src_gids, target_gids, loc, receptor,
weight, delay, lamtha, probability=1.0):
weight, delay, lamtha, probability=1.0, seed=0):
"""Appends connections to connectivity list
Parameters
Expand Down Expand Up @@ -1094,13 +1094,13 @@ def add_connection(self, src_gids, target_gids, loc, receptor,

# Probabilistically define connections
if probability != 1.0:
self._connection_probability(conn, probability)
self._connection_probability(conn, probability, seed)

conn['probability'] = probability

self.connectivity.append(deepcopy(conn))

def _connection_probability(self, conn, probability):
def _connection_probability(self, conn, probability, seed=0):
"""Remove/keep a random subset of connections.
Parameters
Expand All @@ -1111,7 +1111,8 @@ def _connection_probability(self, conn, probability):
probability : float
Probability of connection between any src-target pair.
Defaults to 1.0 producing an all-to-all pattern.
seed : int
Seed for the numpy random number generator.
Notes
-----
num_srcs and num_targets are not updated after pruning connections.
Expand All @@ -1122,6 +1123,8 @@ def _connection_probability(self, conn, probability):
this function. As such, this number does not accurately describe the
connections probability of the original set after successive calls.
"""
# Random number generator for random connection selection
rng = np.random.default_rng(seed)
_validate_type(probability, float, 'probability')
if probability <= 0.0 or probability >= 1.0:
raise ValueError('probability must be in the range (0,1)')
Expand All @@ -1133,7 +1136,7 @@ def _connection_probability(self, conn, probability):
len(all_connections) * probability).astype(int)

# Select a random subset of connections to retain.
new_connections = np.random.choice(
new_connections = rng.choice(
range(len(all_connections)), n_connections, replace=False)
remove_srcs = list()
connection_idx = 0
Expand Down

0 comments on commit 07b3990

Please sign in to comment.