diff --git a/hnn_core/network.py b/hnn_core/network.py index 6f675c8ff..4e1efcf62 100644 --- a/hnn_core/network.py +++ b/hnn_core/network.py @@ -958,7 +958,7 @@ def _all_to_all_connect(self, src_cell, target_cell, nc_dict['A_weight'], nc_dict['A_delay'], nc_dict['lamtha']) def add_connection(self, src_gids, target_gids, loc, receptor, - weight, delay, lamtha, probability=1.0): + weight, delay, lamtha, probability=1.0, seed=0): """Appends connections to connectivity list Parameters @@ -1094,13 +1094,13 @@ def add_connection(self, src_gids, target_gids, loc, receptor, # Probabilistically define connections if probability != 1.0: - self._connection_probability(conn, probability) + self._connection_probability(conn, probability, seed) conn['probability'] = probability self.connectivity.append(deepcopy(conn)) - def _connection_probability(self, conn, probability): + def _connection_probability(self, conn, probability, seed=0): """Remove/keep a random subset of connections. Parameters @@ -1111,7 +1111,8 @@ def _connection_probability(self, conn, probability): probability : float Probability of connection between any src-target pair. Defaults to 1.0 producing an all-to-all pattern. - + seed : int + Seed for the numpy random number generator. Notes ----- num_srcs and num_targets are not updated after pruning connections. @@ -1122,6 +1123,8 @@ def _connection_probability(self, conn, probability): this function. As such, this number does not accurately describe the connections probability of the original set after successive calls. """ + # Random number generator for random connection selection + rng = np.random.default_rng(seed) _validate_type(probability, float, 'probability') if probability <= 0.0 or probability >= 1.0: raise ValueError('probability must be in the range (0,1)') @@ -1133,7 +1136,7 @@ def _connection_probability(self, conn, probability): len(all_connections) * probability).astype(int) # Select a random subset of connections to retain. - new_connections = np.random.choice( + new_connections = rng.choice( range(len(all_connections)), n_connections, replace=False) remove_srcs = list() connection_idx = 0