From f92e3f2809b4c11767eaecc391a45931b1055a2d Mon Sep 17 00:00:00 2001
From: ShawnDu
Date: Thu, 24 Nov 2022 18:53:50 +0800
Subject: [PATCH] Add test_optimizer; fix lint; add flake8 and autopep8 config
 files

---
 .flake8                             | 15 +++
 .github/workflows/linux_ci.yml      |  5 +-
 .gitignore                          |  1 +
 bpl/core/networks.py                | 35 +++++-----
 bpl/core/neurons.py                 | 42 ++++++------
 bpl/core/neurons_base.py            | 33 +++++-----
 bpl/core/runner.py                  | 31 ++++-----
 bpl/core/synapses.py                | 98 ++++++++++++++--------------
 bpl/respa/base.py                   | 99 ++++++++++++++++-------------
 bpl/respa/optimizer.py              | 50 +++++++--------
 bpl/respa/res_manager.py            | 15 +++--
 bpl/respa/utils.py                  |  2 +-
 examples/brain_simulation_multi.py  | 27 ++++----
 examples/brain_simulation_single.py | 33 +++++-----
 examples/callback.py                | 24 ++++---
 examples/test.ipynb                 |  6 +-
 pyproject.toml                      |  4 ++
 setup.py                            | 84 ++++++++++++------------
 tests/base.py                       |  9 +++
 tests/test_callback.py              | 62 ------------------
 tests/test_optimizer.py             | 32 ++++++++++
 tests/test_respa_base.py            | 35 +++++-----
 22 files changed, 381 insertions(+), 361 deletions(-)
 create mode 100644 .flake8
 create mode 100644 pyproject.toml
 create mode 100644 tests/base.py
 delete mode 100644 tests/test_callback.py
 create mode 100644 tests/test_optimizer.py

diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..0dae043
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,15 @@
+[flake8]
+select = B,C,E,F,P,T4,W,B9
+indent-size = 2
+max-line-length = 120
+# C408 ignored because we like the dict keyword argument syntax
+# E501 is not flexible enough, we're using B950 instead
+ignore =
+    E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,E111,
+    EXE001,
+    B007,B008,
+    C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415,
+    E731
+optional-ascii-coding = True
+
+exclude = */__init__.py
diff --git a/.github/workflows/linux_ci.yml b/.github/workflows/linux_ci.yml
index e98d272..41b76b2 100644
--- a/.github/workflows/linux_ci.yml
+++ b/.github/workflows/linux_ci.yml
@@ -29,10 +29,7 @@ jobs:
         python setup.py install
     - name: Lint with flake8
      run: |
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+        flake8 . 
--show-source --statistics - name: Test with pytest run: | pytest tests diff --git a/.gitignore b/.gitignore index 2ab793b..95febc3 100644 --- a/.gitignore +++ b/.gitignore @@ -128,3 +128,4 @@ dmypy.json # Pyre type checker .pyre/ .vscode +.idea diff --git a/bpl/core/networks.py b/bpl/core/networks.py index c4c83cb..e5f774a 100644 --- a/bpl/core/networks.py +++ b/bpl/core/networks.py @@ -5,7 +5,8 @@ from .base import RemoteDynamicalSystem from mpi4py import MPI import platform -if platform.system()!='Windows': + +if platform.system() != 'Windows': import mpi4jax import numpy as np @@ -15,19 +16,19 @@ class RemoteNetwork(dyn.Network, RemoteDynamicalSystem): """ def __init__( - self, - *ds_tuple, - comm=MPI.COMM_WORLD, - name: str = None, - mode: Mode = normal, - **ds_dict + self, + *ds_tuple, + comm=MPI.COMM_WORLD, + name: str = None, + mode: Mode = normal, + **ds_dict ): super(RemoteNetwork, self).__init__(*ds_tuple, - name=name, - mode=mode, - **ds_dict) + name=name, + mode=mode, + **ds_dict) self.comm = comm - if self.comm == None: + if self.comm is None: self.rank = None else: self.rank = self.comm.Get_rank() @@ -39,7 +40,7 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None): if nodes is None: nodes = tuple(self.nodes(level=1, include_self=False).subset(dyn.DynamicalSystem).unique().values()) elif isinstance(nodes, dyn.DynamicalSystem): - nodes = (nodes, ) + nodes = (nodes,) elif isinstance(nodes, dict): nodes = tuple(nodes.values()) if not isinstance(nodes, (tuple, list)): @@ -48,14 +49,14 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None): if hasattr(node, 'comm'): for name in node.local_delay_vars: if self.rank == node.source_rank: - if platform.system()=='Windows': + if platform.system() == 'Windows': self.comm.send(len(node.pre.spike), dest=node.target_rank, tag=2) self.comm.Send(node.pre.spike.to_numpy(), dest=node.target_rank, tag=3) else: token = mpi4jax.send(node.pre.spike.value, dest=node.target_rank, tag=3, comm=self.comm) elif self.rank == node.target_rank: delay = self.remote_global_delay_data[name][0] - if platform.system()=='Windows': + if platform.system() == 'Windows': pre_len = self.comm.recv(source=node.source_rank, tag=2) target = np.empty(pre_len, dtype=np.bool_) self.comm.Recv(target, source=node.source_rank, tag=3) @@ -81,14 +82,14 @@ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): if hasattr(node, 'comm'): for name in node.local_delay_vars: if self.rank == node.source_rank: - if platform.system()=='Windows': + if platform.system() == 'Windows': self.comm.send(len(node.pre.spike), dest=node.target_rank, tag=4) self.comm.Send(node.pre.spike.to_numpy(), dest=node.target_rank, tag=5) else: token = mpi4jax.send(node.pre.spike.value, dest=node.target_rank, tag=4, comm=self.comm) elif self.rank == node.target_rank: delay = self.remote_global_delay_data[name][0] - if platform.system()=='Windows': + if platform.system() == 'Windows': pre_len = self.comm.recv(source=node.source_rank, tag=4) target = np.empty(pre_len, dtype=np.bool_) self.comm.Recv(target, source=node.source_rank, tag=5) @@ -100,4 +101,4 @@ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None): for name in node.local_delay_vars: delay = self.global_delay_data[name][0] target = self.global_delay_data[name][1] - delay.reset(target.value) \ No newline at end of file + delay.reset(target.value) diff --git a/bpl/core/neurons.py b/bpl/core/neurons.py index 2abc382..ab9faf1 100644 --- a/bpl/core/neurons.py +++ b/bpl/core/neurons.py @@ -14,31 
+14,31 @@ class ProxyLIF(ProxyNeuGroup):
   """
 
   def __init__(
-      self,
-      size: Shape,
-      keep_size: bool = False,
+        self,
+        size: Shape,
+        keep_size: bool = False,
 
-      # other parameter
-      V_rest: Union[float, Array, Initializer, Callable] = 0.,
-      V_reset: Union[float, Array, Initializer, Callable] = -5.,
-      V_th: Union[float, Array, Initializer, Callable] = 20.,
-      R: Union[float, Array, Initializer, Callable] = 1.,
-      tau: Union[float, Array, Initializer, Callable] = 10.,
-      tau_ref: Optional[Union[float, Array, Initializer, Callable]] = None,
-      V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
-      noise: Optional[Union[float, Array, Initializer, Callable]] = None,
-      method: str = 'exp_auto',
-      name: Optional[str] = None,
+        # other parameter
+        V_rest: Union[float, Array, Initializer, Callable] = 0.,
+        V_reset: Union[float, Array, Initializer, Callable] = -5.,
+        V_th: Union[float, Array, Initializer, Callable] = 20.,
+        R: Union[float, Array, Initializer, Callable] = 1.,
+        tau: Union[float, Array, Initializer, Callable] = 10.,
+        tau_ref: Optional[Union[float, Array, Initializer, Callable]] = None,
+        V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
+        noise: Optional[Union[float, Array, Initializer, Callable]] = None,
+        method: str = 'exp_auto',
+        name: Optional[str] = None,
 
-      # training parameter
-      mode: Mode = normal,
-      spike_fun: Callable = bm.spike_with_sigmoid_grad,
+        # training parameter
+        mode: Mode = normal,
+        spike_fun: Callable = bm.spike_with_sigmoid_grad,
   ):
     # initialization
     super(ProxyLIF, self).__init__(size=size,
-                              name=name,
-                              keep_size=keep_size,
-                              mode=mode)
+                                   name=name,
+                                   keep_size=keep_size,
+                                   mode=mode)
     check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)
 
     # parameters
@@ -62,7 +62,7 @@ def __init__(
     # It is useful for JIT in a multi-device environment.
 sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool  # the gradient of spike is a float
     self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
-
+
     if self.tau_ref is not None:
       self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, 0, mode)
       self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), 0, mode)
diff --git a/bpl/core/neurons_base.py b/bpl/core/neurons_base.py
index 5d49651..bd28b81 100644
--- a/bpl/core/neurons_base.py
+++ b/bpl/core/neurons_base.py
@@ -12,31 +12,32 @@ class ProxyNeuGroup(dyn.NeuGroup):
   """
 
   def __init__(
-      self,
-      size: Shape,
-      keep_size: bool = False,
-      name: str = None,
-      mode: Mode = normal,
+        self,
+        size: Shape,
+        keep_size: bool = False,
+        name: str = None,
+        mode: Mode = normal,
   ):
     # initialize
     super(ProxyNeuGroup, self).__init__(size=size,
-                                   name=name,
-                                   keep_size=keep_size,
-                                   mode=mode)
+                                        name=name,
+                                        keep_size=keep_size,
+                                        mode=mode)
 
   def __getitem__(self, item):
     return ProxyNeuGroupView(target=self, index=item, keep_size=self.keep_size)
 
+
 class ProxyNeuGroupView(ProxyNeuGroup):
   """A view for a neuron group instance in a multi-device environment."""
 
   def __init__(
-      self,
-      target: ProxyNeuGroup,
-      index: Union[slice, Sequence, Array],
-      name: str = None,
-      mode: Mode = None,
-      keep_size: bool = False
+        self,
+        target: ProxyNeuGroup,
+        index: Union[slice, Sequence, Array],
+        name: str = None,
+        mode: Mode = None,
+        keep_size: bool = False
   ):
     # check target
     if not isinstance(target, dyn.DynamicalSystem):
@@ -46,7 +47,7 @@ def __init__(
     if isinstance(index, (int, slice)):
       index = (index,)
     self.index = index  # the slice
-
+
     # check slicing
     var_shapes = target.varshape
     if len(self.index) > len(var_shapes):
@@ -78,4 +79,4 @@ def __init__(
     sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool  # the gradient of spike is a float
     self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
     # Proxy neuron needs 'V' attribute during distributed network updating.
-    self.V = variable_(bm.zeros, 0, mode)
\ No newline at end of file
+    self.V = variable_(bm.zeros, 0, mode)
diff --git a/bpl/core/runner.py b/bpl/core/runner.py
index 2d71903..1683553 100644
--- a/bpl/core/runner.py
+++ b/bpl/core/runner.py
@@ -1,8 +1,4 @@
-import time
 from typing import Dict, Union, Sequence, Callable
-
-import numpy as np
-import tqdm.auto
 from jax.tree_util import tree_map
 from jax.experimental.host_callback import id_tap
 
@@ -14,18 +10,18 @@ class BplRunner(dyn.DSRunner):
 
   def __init__(
-      self,
-      target: dyn.DynamicalSystem,
-
-      # inputs for target variables
-      inputs: Sequence = (),
-      fun_inputs: Callable = None,
-
-      # extra info
-      dt: float = None,
-      t0: Union[float, int] = 0.,
-      callback: Callable = None,
-      **kwargs
+        self,
+        target: dyn.DynamicalSystem,
+
+        # inputs for target variables
+        inputs: Sequence = (),
+        fun_inputs: Callable = None,
+
+        # extra info
+        dt: float = None,
+        t0: Union[float, int] = 0.,
+        callback: Callable = None,
+        **kwargs
   ):
     super(BplRunner, self).__init__(
       target=target, inputs=inputs, fun_inputs=fun_inputs,
@@ -33,7 +29,8 @@ def __init__(
     self.callback = callback
 
   def f_predict(self, shared_args: Dict = None):
-    if shared_args is None: shared_args = dict()
+    if shared_args is None:
+      shared_args = dict()
 
     shared_kwargs_str = serialize_kwargs(shared_args)
     if shared_kwargs_str not in self._f_predict_compiled:
diff --git a/bpl/core/synapses.py b/bpl/core/synapses.py
index c213e9c..9b37358 100644
--- a/bpl/core/synapses.py
+++ b/bpl/core/synapses.py
@@ -16,7 +16,8 @@
 from mpi4py import MPI
 import jax.numpy as jnp
 import platform
-if platform.system()!='Windows':
+
+if platform.system() != 'Windows':
   import mpi4jax
 
 
@@ -25,35 +26,35 @@ class RemoteExponential(dyn.synapses.Exponential, RemoteDynamicalSystem):
   """
 
   def __init__(
-      self,
-      source_rank,
-      pre: dyn.NeuGroup,
-      target_rank,
-      post: dyn.NeuGroup,
-      conn: Union[TwoEndConnector, Array, Dict[str, Array]],
-      comm=MPI.COMM_WORLD,
-      output: dyn.SynOut = CUBA(),
-      stp: Optional[dyn.SynSTP] = None,
-      comp_method: str = 'sparse',
-      g_max: Union[float, Array, Initializer, Callable] = 1.,
-      delay_step: Union[int, Array, Initializer, Callable] = None,
-      tau: Union[float, Array] = 8.0,
-      method: str = 'exp_auto',
-
-      # other parameters
-      name: str = None,
-      mode: Mode = normal,
-      stop_spike_gradient: bool = False,
-
+        self,
+        source_rank,
+        pre: dyn.NeuGroup,
+        target_rank,
+        post: dyn.NeuGroup,
+        conn: Union[TwoEndConnector, Array, Dict[str, Array]],
+        comm=MPI.COMM_WORLD,
+        output: dyn.SynOut = CUBA(),
+        stp: Optional[dyn.SynSTP] = None,
+        comp_method: str = 'sparse',
+        g_max: Union[float, Array, Initializer, Callable] = 1.,
+        delay_step: Union[int, Array, Initializer, Callable] = None,
+        tau: Union[float, Array] = 8.0,
+        method: str = 'exp_auto',
+
+        # other parameters
+        name: str = None,
+        mode: Mode = normal,
+        stop_spike_gradient: bool = False,
+
   ):
     super(RemoteExponential, self).__init__(pre=pre,
-                                      post=post,
-                                      conn=conn,
-                                      output=output,
-                                      stp=stp,
-                                      name=name,
-                                      mode=mode,
-                                      )
+                                            post=post,
+                                            conn=conn,
+                                            output=output,
+                                            stp=stp,
+                                            name=name,
+                                            mode=mode,
+                                            )
     # parameters
     self.stop_spike_gradient = stop_spike_gradient
     self.comp_method = comp_method
@@ -69,10 +70,10 @@ def __init__(
     self.target_rank = target_rank
     self.rank = self.comm.Get_rank()
     if self.rank == source_rank:
-      #Make sure the same neuron group only deliver its spike one time during one step network simulation
+      # Make sure the same neuron group delivers its spike only once during one network simulation step
       if self.pre.name not in self.remote_first_send_mark:
         self.remote_first_send_mark.append(self.pre.name)
-        if platform.system()=='Windows':
+        if platform.system() == 'Windows':
          self.comm.send(len(self.pre.spike), dest=target_rank, tag=0)
          self.comm.Send(self.pre.spike.to_numpy(), dest=target_rank, tag=1)
        else:
@@ -81,10 +82,10 @@ def __init__(
     elif self.rank == target_rank:
       # connections and weights
       self.g_max, self.conn_mask = self.init_weights(g_max, comp_method, sparse_data='csr')
-
+
       if self.pre.name not in self.remote_first_send_mark:
         self.remote_first_send_mark.append(self.pre.name)
-        if platform.system()=='Windows':
+        if platform.system() == 'Windows':
          pre_len = self.comm.recv(source=source_rank, tag=0)
          pre_spike = np.empty(pre_len, dtype=np.bool_)
          self.comm.Recv(pre_spike, source=source_rank, tag=1)
@@ -97,11 +98,11 @@ def __init__(
     self.integral = odeint(lambda g, t: -g / self.tau, method=method)
 
   def remote_register_delay(
-      self,
-      identifier: str,
-      delay_step: Optional[Union[int, Array, Callable, Initializer]],
-      delay_target: bm.Variable,
-      initial_delay_data: Union[Initializer, Callable, Array, float, int, bool] = None,
+        self,
+        identifier: str,
+        delay_step: Optional[Union[int, Array, Callable, Initializer]],
+        delay_target: bm.Variable,
+        initial_delay_data: Union[Initializer, Callable, Array, float, int, bool] = None,
   ):
     """Register delay variable in a multi-device environment.
     """
@@ -174,27 +175,32 @@ def update(self, tdi, pre_spike=None):
 
     # update sub-components
     self.output.update(tdi)
-    if self.stp is not None: self.stp.update(tdi, pre_spike)
+    if self.stp is not None:
+      self.stp.update(tdi, pre_spike)
 
     # post values
     if isinstance(self.conn, All2All):
       syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
-      if self.stp is not None: syn_value = self.stp(syn_value)
+      if self.stp is not None:
+        syn_value = self.stp(syn_value)
       post_vs = self.syn2post_with_all2all(syn_value, self.g_max)
     elif isinstance(self.conn, One2One):
       syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
-      if self.stp is not None: syn_value = self.stp(syn_value)
+      if self.stp is not None:
+        syn_value = self.stp(syn_value)
       post_vs = self.syn2post_with_one2one(syn_value, self.g_max)
     else:
       if self.comp_method == 'sparse':
         f = lambda s: bm.pre2post_event_sum(s, self.conn_mask, self.post.num, self.g_max)
-        if isinstance(self.mode, BatchingMode): f = vmap(f)
+        if isinstance(self.mode, BatchingMode):
+          f = vmap(f)
         post_vs = f(pre_spike)
         # if not isinstance(self.stp, _NullSynSTP):
         #   raise NotImplementedError()
       else:
         syn_value = bm.asarray(pre_spike, dtype=bm.dftype())
-        if self.stp is not None: syn_value = self.stp(syn_value)
+        if self.stp is not None:
+          syn_value = self.stp(syn_value)
         post_vs = self.syn2post_with_dense(syn_value, self.g_max, self.conn_mask)
 
     # updates
     self.g.value = self.integral(self.g.value, t, dt) + post_vs
@@ -202,10 +208,10 @@ def update(self, tdi, pre_spike=None):
     return self.output(self.g)
 
   def remote_get_delay_data(
-      self,
-      identifier: str,
-      delay_step: Optional[Union[int, bm.JaxArray, jnp.DeviceArray]],
-      *indices: Union[int, slice, bm.JaxArray, jnp.DeviceArray],
+        self,
+        identifier: str,
+        delay_step: Optional[Union[int, bm.JaxArray, jnp.DeviceArray]],
+        *indices: Union[int, slice, bm.JaxArray, jnp.DeviceArray],
   ):
     """Get delay data according to the provided delay steps in a multi-device environment.
""" diff --git a/bpl/respa/base.py b/bpl/respa/base.py index fa64911..0ab192b 100644 --- a/bpl/respa/base.py +++ b/bpl/respa/base.py @@ -5,19 +5,18 @@ from brainpy.connect import TwoEndConnector from typing import Union, Sequence, Callable, Tuple, Dict import jax.tree_util -from brainpy import tools from .res_manager import ResManager import bpl try: from mpi4py import MPI + mpi_size = MPI.COMM_WORLD.Get_size() mpi_rank = MPI.COMM_WORLD.Get_rank() -except: +except ImportError: mpi_size = 1 mpi_rank = 0 - pop_id = 0 @@ -27,14 +26,20 @@ def get_pop_id(): return pop_id +def reset(): + global pop_id + ResManager.clear() + pop_id = 0 + + class BaseNeuron: proxy_neurons = {} def __init__( - self, - shape: Shape, - *args, - **kwargs + self, + shape: Shape, + *args, + **kwargs ): self.args = args self.kwargs = kwargs @@ -131,12 +136,12 @@ def __call__(self, cls): class BaseSynapse: def __init__( - self, - pre: Union[BaseNeuron, Tuple[BaseNeuron, Union[slice, Sequence, Array]]], - post: Union[BaseNeuron, Tuple[BaseNeuron, Union[slice, Sequence, Array]]], - conn: Union[TwoEndConnector, Array, Dict[str, Array]], - *args, - **kwargs + self, + pre: Union[BaseNeuron, Tuple[BaseNeuron, Union[slice, Sequence, Array]]], + post: Union[BaseNeuron, Tuple[BaseNeuron, Union[slice, Sequence, Array]]], + conn: Union[TwoEndConnector, Array, Dict[str, Array]], + *args, + **kwargs ): self.pre = pre self.post = post @@ -197,7 +202,7 @@ def build(self): if pre_pid == post_pid and pre_pid == mpi_rank: self.lowref = self.model_class( - pre, post, self.conn, *self.args, **self.kwargs) + pre, post, self.conn, *self.args, **self.kwargs) elif pre_pid == mpi_rank: if post not in BaseNeuron.proxy_neurons: tmp_ = bpl.neurons.ProxyLIF(post_shape) @@ -207,7 +212,7 @@ def build(self): if post_slice is not None: tmp_ = tmp_[post_slice] self.lowref = self.model_class_remote( - pre_pid, pre, post_pid, tmp_, conn=self.conn, *self.args, **self.kwargs) + pre_pid, pre, post_pid, tmp_, conn=self.conn, *self.args, **self.kwargs) elif post_pid == mpi_rank: if pre not in BaseNeuron.proxy_neurons: tmp_ = bpl.neurons.ProxyLIF(pre_shape) @@ -217,32 +222,33 @@ def build(self): if pre_slice is not None: tmp_ = tmp_[pre_slice] self.lowref = self.model_class_remote( - pre_pid, tmp_, post_pid, post, conn=self.conn, *self.args, **self.kwargs) + pre_pid, tmp_, post_pid, post, conn=self.conn, *self.args, **self.kwargs) return self.lowref class LIF(BaseNeuron): def __init__( - self, - shape: Shape, - *args, - **kwargs + self, + shape: Shape, + *args, + **kwargs ): super(LIF, self).__init__(shape, *args, **kwargs) self.model_class = dyn.LIF + # another way to define respa LIF # BaseNeuron.register(dyn.LIF) class Exponential(BaseSynapse): def __init__( - self, - pre: Union[BaseNeuron, Tuple[BaseNeuron, Union[slice, Sequence, Array]]], - post: Union[BaseNeuron, Tuple[BaseNeuron, Union[slice, Sequence, Array]]], - conn: Union[TwoEndConnector, Array, Dict[str, Array]], - *args, - **kwargs + self, + pre: Union[BaseNeuron, Tuple[BaseNeuron, Union[slice, Sequence, Array]]], + post: Union[BaseNeuron, Tuple[BaseNeuron, Union[slice, Sequence, Array]]], + conn: Union[TwoEndConnector, Array, Dict[str, Array]], + *args, + **kwargs ): super().__init__(pre, post, conn, *args, **kwargs) self.model_class = dyn.synapses.Exponential @@ -264,9 +270,9 @@ def build_all_population_synapse(self): for syn in ResManager.syns: syn.build() self.lowref.register_implicit_nodes( - *map(lambda x: x.lowref, ResManager.pops)) + *map(lambda x: x.lowref, ResManager.pops)) 
     self.lowref.register_implicit_nodes(
-        *map(lambda x: x.lowref, ResManager.syns))
+      *map(lambda x: x.lowref, ResManager.syns))
 
   def build(self):
     self.pops_ = []
@@ -277,15 +283,17 @@ def reg_pop_syn(v):
         self.pops_.append(v)
       elif isinstance(v, BaseSynapse):
         self.syns_.append(v)
+
     jax.tree_util.tree_map(reg_pop_syn, self.__dict__)
 
     def simple_split(pops_):
       res = [[] for i in range(mpi_size)]
       avg = len(pops_) // mpi_size
       for i in range(mpi_size):
-        res[i].extend(pops_[i*avg:i*avg+avg])
-      res[-1].extend(pops_[avg*mpi_size:])
+        res[i].extend(pops_[i * avg:i * avg + avg])
+      res[-1].extend(pops_[avg * mpi_size:])
       return res
+
     self.pops_by_rank = simple_split(self.pops_)
     offset = 1
     for i in range(mpi_size):
@@ -314,17 +322,17 @@ def register_vars(self, *variables, **named_variables):
 
 
 class DSRunner:
   def __init__(
-      self,
-      target: Union[dyn.DynamicalSystem, Network],
-      # inputs for target variables
-      inputs: Sequence = (),
-      fun_inputs: Callable = None,
-      # extra info
-      dt: float = None,
-      t0: Union[float, int] = 0.,
-      spike_callback: Callable = None,
-      volt_callback: Callable = None,
-      **kwargs
+        self,
+        target: Union[dyn.DynamicalSystem, Network],
+        # inputs for target variables
+        inputs: Sequence = (),
+        fun_inputs: Callable = None,
+        # extra info
+        dt: float = None,
+        t0: Union[float, int] = 0.,
+        spike_callback: Callable = None,
+        volt_callback: Callable = None,
+        **kwargs
   ):
     if not isinstance(target, (Network, dyn.DynamicalSystem)):
       raise ValueError(type(target))
@@ -337,7 +345,7 @@ def _callback(t: float, d: dict):
       if k == 'spike' and spike_callback:
        tmp = ''
        for i, j in enumerate(v):
-          if j == True:
+          if j:
            tmp += '{},{:.2f}\n'.format(i + 1, t)
        spike_callback(tmp)
      if k == 'V' and volt_callback:
@@ -345,12 +353,13 @@ def _callback(t: float, d: dict):
        for i, j in enumerate(v):
          tmp += '{},{:.2f},{:.2f}\n'.format(i + 1, t, j)
        volt_callback(tmp)
+
    c = _callback if spike_callback or volt_callback else None
 
    self.lowref = bpl.BplRunner(
-        target=target, inputs=inputs,
-        fun_inputs=fun_inputs, dt=dt,
-        t0=t0, callback=c, **kwargs)
+      target=target, inputs=inputs,
+      fun_inputs=fun_inputs, dt=dt,
+      t0=t0, callback=c, **kwargs)
 
   def __getattr__(self, __name: str):
     return self.lowref.__getattribute__(__name)
diff --git a/bpl/respa/optimizer.py b/bpl/respa/optimizer.py
index 9d8c012..d3ee781 100644
--- a/bpl/respa/optimizer.py
+++ b/bpl/respa/optimizer.py
@@ -1,13 +1,12 @@
 import numpy as np
-from brainpy.connect import FixedProb
-from typing import list
+from brainpy.connect.random_conn import FixedProb
+from typing import List, Tuple
 from .res_manager import ResManager
 from .base import mpi_size
 
 
 class Optimizer():
-
   OPT_GREEDY_INIT = 1
 
   def __init__(self, opt_method=OPT_GREEDY_INIT, device_memory=40, device_capability=60) -> None:
@@ -19,19 +18,19 @@ def run(self):
     if self.opt_method == self.OPT_GREEDY_INIT:
       self.run_greedy_init()
 
-  def get_edge_weight_matrix(self, syns: list, total_pop_num: int) -> list[[int]]:
+  def get_edge_weight_matrix(self, syns=[], total_pop_num=0) -> List[List[int]]:
     """Get the edge weight matrix, the traffic matrix between populations.
 
     Parameters
     ----------
-    syns : list
-      synapse list, `ResManager.syns`
+    syns : List
+      synapse list, `ResManager.syns`
     total_pop_num : int
       total number of populations
 
     Returns
     -------
-    list[[int]]
+    List[List[int]]
       two-dimensional np array
     """
     edge_weight_matrix = np.zeros((total_pop_num, total_pop_num))
@@ -46,16 +45,16 @@ def get_edge_weight_matrix(self, syns: list, total_pop_num: int) -> list[[int]]:
 
       if isinstance(conn, FixedProb):
         edge_weight = pre_num * post_num * conn.prob
-        edge_weight = edge_weight_matrix[pre_id][post_id]
-        edge_weight_matrix[pre_id][post_id] += edge_weight
+        last_edge_weight = edge_weight_matrix[pre_id - 1][post_id - 1]
+        edge_weight_matrix[pre_id - 1][post_id - 1] = last_edge_weight + edge_weight
     return edge_weight_matrix
 
-  def prepare_input(self) -> tuple(list[[int]], int, int):
+  def prepare_input(self) -> Tuple[List[List[int]], int, int]:
     """ prepare optimizer input
 
     Returns
     -------
-    tuple(list[[int]], int, int)
+    Tuple[List[List[int]], int, int]
      tuple(edge_weight_matrix, memory_used, memory_capacity)
    """
 
@@ -68,13 +67,13 @@ def prepare_input(self) -> tuple(list[[int]], int, int):
       sum_hardware_store[1] = sum_hardware_store[1] + self.device_capability
 
     edge_weight_matrix = self.get_edge_weight_matrix(
-        ResManager.syns, total_num=total_num)
+      ResManager.syns, total_pop_num=total_num)
 
     memory_used = int((edge_weight_matrix.sum(0) * 10 + pop_neuron_count * 300) /
                       (1024 ** 3))
     memory_capacity = int(sum_hardware_store[0])
 
-    return edge_weight_matrix, memory_used, memory_capacity
+    return (edge_weight_matrix, memory_used, memory_capacity)
 
   def run_greedy_init(self):
     """Greedy initialization algorithm. Each time find the method that can reduce
@@ -84,14 +83,14 @@ def run_greedy_init(self):
     A, memory_used, memory_capacity = self.prepare_input()
     population_num = A.shape[0]
     assert A.shape == (population_num, population_num)
-    assert memory_used.shape == (population_num, )
+    assert memory_used.shape == (population_num,)
     device_num = len(memory_capacity)
-    assert memory_capacity.shape == (device_num, )
+    assert memory_capacity.shape == (device_num,)
     assert np.all(A >= 0)
 
     capacity_each = memory_used.sum() / device_num
     memory_capacity[np.where(memory_capacity > capacity_each)] = capacity_each * 1.01
-
+
     # each population is assigned to one device
     result = {}
     # memory remaining for each device
@@ -118,30 +117,27 @@ def run_greedy_init(self):
           result[device_index] = r
 
          remote_conn_pre = remote_conn
-          remote_conn = remote_conn + \
-              A[population_index, :].sum() - A[population_index, result.get(device_index)].sum()
+          remote_conn = remote_conn + A[population_index, :].sum() - A[population_index, result.get(device_index)].sum()
          new_capacity = device_available[device_index] - memory_used[population_index] - (
-              remote_conn - remote_conn_pre) * 280 / 1024 ** 3
+            remote_conn - remote_conn_pre) * 280 / 1024 ** 3
          if new_capacity >= 0:
            device_available[device_index] = new_capacity
-            move_gain[device_index] += A[:, population_index] + \
-                A[population_index, :]
+            move_gain[device_index] += A[:, population_index] + A[population_index, :]
            move_gain[device_index, memory_used > new_capacity] = -float('inf')
            move_gain[:, population_index] = -float('inf')
          else:
            move_gain[device_index, :] = -float('inf')
-            if device_index+1 > device_num-1:
+            if device_index + 1 > device_num - 1:
              print("resource not enough")
              break
-            device_available[device_index+1] -= memory_used[population_index]
-            new_capacity = device_available[device_index+1]
+            device_available[device_index + 1] -= memory_used[population_index]
+            new_capacity = 
device_available[device_index + 1] remote_conn = 0 - move_gain[device_index + 1] += A[:, - population_index] + A[population_index, :] - move_gain[device_index+1, memory_used > new_capacity] = -float('inf') + move_gain[device_index + 1] += A[:, population_index] + A[population_index, :] + move_gain[device_index + 1, memory_used > new_capacity] = -float('inf') move_gain[:, population_index] = -float('inf') r = result.get(device_index + 1, []) r.append(population_index) diff --git a/bpl/respa/res_manager.py b/bpl/respa/res_manager.py index 436c654..65451a6 100644 --- a/bpl/respa/res_manager.py +++ b/bpl/respa/res_manager.py @@ -1,14 +1,19 @@ - class ResManager(): - # [BaseNeuron(), ] pops = [] - + # {id: BaseNeuron()}, id is BaseNeuron's id pops_by_id = {} - + # [BaseSynapse(), ] syns = [] - + # element is an array, like [BaseNeuron(), ], pops_by_rank can be expressed as: {device_index: [BaseNeuron(), ] } pops_by_rank = {} + + @classmethod + def clear(cls): + cls.pops = [] + cls.pops_by_id = {} + cls.syns = [] + cls.pops_by_rank = {} diff --git a/bpl/respa/utils.py b/bpl/respa/utils.py index 48e3976..20f44c5 100644 --- a/bpl/respa/utils.py +++ b/bpl/respa/utils.py @@ -20,7 +20,7 @@ def input_transform(pops: Sequence): for pop in pops: try: tmp = pop[0].input - input_trans.append((tmp, pop[1])+pop[2:]) + input_trans.append((tmp, pop[1]) + pop[2:]) except Exception as e: continue return input_trans diff --git a/examples/brain_simulation_multi.py b/examples/brain_simulation_multi.py index 5eb0b40..6087650 100644 --- a/examples/brain_simulation_multi.py +++ b/examples/brain_simulation_multi.py @@ -1,8 +1,9 @@ import sys + sys.path.append('../') import bpl import brainpy as bp -import os +import os os.environ['CUDA_VISIBLE_DEVICES'] = "1" @@ -16,36 +17,37 @@ def __init__(self, scale=1.0, method='exp_auto'): # network size num_exc = int(3200 * scale) num_inh = int(800 * scale) - + # neurons pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) 
# synapses we = 0.6 / scale # excitatory synaptic weight (voltage) wi = 6.7 / scale # inhibitory synaptic weight - + if self.rank == 0: self.E1 = bp.neurons.LIF(num_exc, **pars, method=method) self.I1 = bp.neurons.LIF(num_inh, **pars, method=method) - self.E12I1 = bp.synapses.Exponential(self.E1, self.I1, - bp.conn.FixedProb(0.02,seed=1), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., - method=method, - delay_step=1) + self.E12I1 = bp.synapses.Exponential(self.E1, self.I1, + bp.conn.FixedProb(0.02, seed=1), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., + method=method, + delay_step=1) self.I2 = bpl.neurons.ProxyLIF(num_inh, **pars, method=method) elif self.rank == 1: self.E1 = bpl.neurons.ProxyLIF(num_exc, **pars, method=method) self.I1 = bpl.neurons.ProxyLIF(num_inh, **pars, method=method) self.I2 = bp.neurons.LIF(num_inh, **pars, method=method) - self.remoteE12I2 = bpl.synapses.RemoteExponential(0, self.E1, 1, self.I2, - bp.conn.FixedProb(0.02,seed=1), + self.remoteE12I2 = bpl.synapses.RemoteExponential(0, self.E1, 1, self.I2, + bp.conn.FixedProb(0.02, seed=1), output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., + tau=5., method=method, delay_step=1 ) + def run_model_v1(): net = EINet_V1(scale=1., method='exp_auto') runner = bp.dyn.DSRunner( @@ -63,4 +65,3 @@ def run_model_v1(): if __name__ == '__main__': run_model_v1() - \ No newline at end of file diff --git a/examples/brain_simulation_single.py b/examples/brain_simulation_single.py index d78530c..bb20d2b 100644 --- a/examples/brain_simulation_single.py +++ b/examples/brain_simulation_single.py @@ -1,4 +1,5 @@ import sys + sys.path.append('../') import brainpy as bp @@ -12,30 +13,31 @@ def __init__(self, scale=1.0, method='exp_auto'): # network size num_exc = int(3200 * scale) num_inh = int(800 * scale) - + # neurons pars = dict(V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.) 
# synapses we = 0.6 / scale # excitatory synaptic weight (voltage) wi = 6.7 / scale # inhibitory synaptic weight - + self.E1 = bp.neurons.LIF(num_exc, **pars, method=method) self.I1 = bp.neurons.LIF(num_inh, **pars, method=method) - self.E12I1 = bp.synapses.Exponential(self.E1, self.I1, - bp.conn.FixedProb(0.02,seed=1), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., - method=method, - delay_step=1) + self.E12I1 = bp.synapses.Exponential(self.E1, self.I1, + bp.conn.FixedProb(0.02, seed=1), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., + method=method, + delay_step=1) self.I2 = bp.neurons.LIF(num_inh, **pars, method=method) - self.E12I2 = bp.synapses.Exponential(self.E1, self.I2, - bp.conn.FixedProb(0.02,seed=1), - output=bp.synouts.COBA(E=0.), g_max=we, - tau=5., - method=method, - delay_step=1 - ) + self.E12I2 = bp.synapses.Exponential(self.E1, self.I2, + bp.conn.FixedProb(0.02, seed=1), + output=bp.synouts.COBA(E=0.), g_max=we, + tau=5., + method=method, + delay_step=1 + ) + def run_model_v1(): net = EINet_V1(scale=1., method='exp_auto') @@ -53,4 +55,3 @@ def run_model_v1(): if __name__ == '__main__': run_model_v1() - \ No newline at end of file diff --git a/examples/callback.py b/examples/callback.py index 707a2c4..a91b7f5 100644 --- a/examples/callback.py +++ b/examples/callback.py @@ -1,17 +1,19 @@ import bpl import brainpy as bp + class MyNetwork(bpl.Network): def __init__(self, *ds_tuple): super(MyNetwork, self).__init__(ds_tuple) self.a = bpl.LIF(20, V_rest=-60., V_th=-50., V_reset=-60., tau=20., - tau_ref=5., method='exp_auto', V_initializer=bp.initialize.Normal(-55., 2.)) + tau_ref=5., method='exp_auto', V_initializer=bp.initialize.Normal(-55., 2.)) self.b = bpl.LIF(10, V_rest=-60., V_th=-50., V_reset=-60., tau=20., - tau_ref=5., method='exp_auto', V_initializer=bp.initialize.Normal(-55., 2.)) + tau_ref=5., method='exp_auto', V_initializer=bp.initialize.Normal(-55., 2.)) # self.c = bpl.Exponential(ds_tuple[0], self.a, bp.conn.FixedProb( # 0.02), g_max=10, tau=5., output=bp.synouts.COBA(E=0.), method='exp_auto') self.d = bpl.Exponential(self.a[100:], self.b, bp.conn.FixedProb( - 0.2, seed=123), g_max=10, tau=5., output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) + 0.2, seed=123), g_max=10, tau=5., output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) + net = MyNetwork() net.build() @@ -21,19 +23,23 @@ def __init__(self, *ds_tuple): monitors = {} monitors.update(monitor_spike) monitors.update(monitor_volt) + + def spike(s: str): print(s) + def volt(s: str): print(s) + runner = bpl.DSRunner( - net, - monitors=monitors, - inputs=inputs, - jit=False, - spike_callback=spike, - volt_callback=volt, + net, + monitors=monitors, + inputs=inputs, + jit=False, + spike_callback=spike, + volt_callback=volt, ) runner.run(10.) 
diff --git a/examples/test.ipynb b/examples/test.ipynb index 258f063..caa0ac1 100644 --- a/examples/test.ipynb +++ b/examples/test.ipynb @@ -149,7 +149,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.7 ('base')", + "display_name": "Python 3.8.10 64-bit", "language": "python", "name": "python3" }, @@ -163,12 +163,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.7" + "version": "3.8.10" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "08d23fa5f426e3cac3d452d649c8af7e8e0cfeeecb0333911b5f2b6d13f711fa" + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" } } }, diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..1d64e18 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,4 @@ +[tool.autopep8] +indent-size = 2 +max_line_length = 120 +ignore = "E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,E111, EXE001, B007,B008, C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415, E731" \ No newline at end of file diff --git a/setup.py b/setup.py index 9f5bd2d..2bfb9bb 100644 --- a/setup.py +++ b/setup.py @@ -26,46 +26,46 @@ # setup setup( - name='brainpy-largescale', - version=version, - description='brainpy-largescale depends on brainpy', - long_description=README, - long_description_content_type="text/markdown", - author='NanHu Neuromorphic Computing Laboratory Team', - author_email='nhnao@cnaeit.com', - packages=packages, - python_requires='>=3.7', - install_requires=['numpy>=1.15', 'jax>=0.3.0', 'tqdm', 'brainpy', - 'brainpylib', 'numba', 'mpi4py', 'mpi4jax', 'jax[cpu]==0.3.24'], - url='https://github.com/NH-NCL/brainpy-largescale', - project_urls={ - "Bug Tracker": "https://github.com/NH-NCL/brainpy-largescale/issues", - "Documentation": "https://brainpy.readthedocs.io/", - "Source Code": "https://github.com/NH-NCL/brainpy-largescale", - }, - keywords=('brainpy largescale, ' - 'computational neuroscience, ' - 'brain-inspired computation, ' - 'dynamical systems, ' - 'differential equations, ' - 'brain modeling, ' - 'brain dynamics modeling, ' - 'brain dynamics programming'), - classifiers=[ - 'Natural Language :: English', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Intended Audience :: Science/Research', - 'License :: OSI Approved :: Apache Software License', - 'Topic :: Scientific/Engineering :: Bio-Informatics', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Software Development :: Libraries', - ], - license='Apache-2.0 license', + name='brainpy-largescale', + version=version, + description='brainpy-largescale depends on brainpy', + long_description=README, + long_description_content_type="text/markdown", + author='NanHu Neuromorphic Computing Laboratory Team', + author_email='nhnao@cnaeit.com', + packages=packages, + python_requires='>=3.7', + install_requires=['numpy>=1.15', 'jax>=0.3.0', 'tqdm', 'brainpy', + 'brainpylib', 'numba', 'mpi4py', 'mpi4jax', 'jax[cpu]==0.3.24'], + url='https://github.com/NH-NCL/brainpy-largescale', + project_urls={ + "Bug Tracker": "https://github.com/NH-NCL/brainpy-largescale/issues", + "Documentation": "https://brainpy.readthedocs.io/", + "Source Code": 
"https://github.com/NH-NCL/brainpy-largescale", + }, + keywords=('brainpy largescale, ' + 'computational neuroscience, ' + 'brain-inspired computation, ' + 'dynamical systems, ' + 'differential equations, ' + 'brain modeling, ' + 'brain dynamics modeling, ' + 'brain dynamics programming'), + classifiers=[ + 'Natural Language :: English', + 'Operating System :: OS Independent', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Topic :: Scientific/Engineering :: Bio-Informatics', + 'Topic :: Scientific/Engineering :: Mathematics', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development :: Libraries', + ], + license='Apache-2.0 license', ) diff --git a/tests/base.py b/tests/base.py new file mode 100644 index 0000000..f8f6771 --- /dev/null +++ b/tests/base.py @@ -0,0 +1,9 @@ +import unittest +import bpl + + +class BaseTest(unittest.TestCase): + + def setUp(self) -> None: + bpl.ResManager.clear() + return super().setUp() diff --git a/tests/test_callback.py b/tests/test_callback.py deleted file mode 100644 index ec89eef..0000000 --- a/tests/test_callback.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest -import bpl -import brainpy as bp - - -class BaseFunctionsTestCase(unittest.TestCase): - def testbasefunc(self): - class MyNetwork(bpl.Network): - def __init__(self, *ds_tuple): - super(MyNetwork, self).__init__(ds_tuple) - self.a = bpl.LIF(20, V_rest=-60., V_th=-50., V_reset=-60., tau=20., - tau_ref=5., method='exp_auto', V_initializer=bp.initialize.Normal(-55., 2.)) - self.b = bpl.LIF(10, V_rest=-60., V_th=-50., V_reset=-60., tau=20., - tau_ref=5., method='exp_auto', V_initializer=bp.initialize.Normal(-55., 2.)) - # self.c = bpl.Exponential(ds_tuple[0], self.a, bp.conn.FixedProb( - # 0.02), g_max=10, tau=5., output=bp.synouts.COBA(E=0.), method='exp_auto') - self.d = bpl.Exponential(self.a[100:], self.b, bp.conn.FixedProb( - 0.2, seed=123), g_max=10, tau=5., output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) - - net = MyNetwork() - net.build() - # from mpi4py import MPI - # if MPI.COMM_WORLD.Get_size() == 2: - # if MPI.COMM_WORLD.Get_rank() == 1: - # monitors = {'spikes': net.b.spike} - # # inputs = [] - # else: - # monitors = {} - # # inputs = [(net.a.input, 20.)] - # else: - # monitors = {'spikes': net.b.spike} - # inputs = [(net.a.input, 20.)] - inputs = bpl.input_transform([(net.a, 20)]) - monitor_spike = bpl.monitor_transform([net.a], attr='spike') - monitor_volt = bpl.monitor_transform([net.a], attr='V') - monitors = {} - monitors.update(monitor_spike) - monitors.update(monitor_volt) - def spike(s: str): - print(s) - - def volt(s: str): - print(s) - - runner = bpl.DSRunner( - net, - monitors=monitors, - inputs=inputs, - jit=False, - spike_callback=spike, - volt_callback=volt, - ) - runner.run(10.) 
- - if 'spike' in runner.mon: - bp.visualize.raster_plot(runner.mon.ts, runner.mon['spike'], show=True) - elif 'V' in runner.mon: - bp.visualize.raster_plot(runner.mon.ts, runner.mon['V'], show=True) - - -# if __name__ == '__main__': -# unittest.main() diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..5a5393f --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,32 @@ +from bpl.respa.optimizer import Optimizer +from bpl.respa.res_manager import ResManager +import bpl +import brainpy as bp +import numpy as np +from .base import BaseTest + + +class TestOptimizer(BaseTest): + def test_get_edge_weight_matrix(self): + a = bpl.LIF(2, V_rest=-60., V_th=-50., V_reset=-60., tau=20., + tau_ref=5., method='exp_auto') + b = bpl.LIF(3, V_rest=-60., V_th=-50., V_reset=-60., tau=20., + tau_ref=5., method='exp_auto') + c = bpl.LIF(4, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto') + d = bpl.LIF(5, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5., method='exp_auto') + bpl.Exponential(a, b, bp.conn.FixedProb(1, seed=123), g_max=10, tau=5., + output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) + bpl.Exponential(a, c, bp.conn.FixedProb(1, seed=123), g_max=10, tau=5., + output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) + bpl.Exponential(b, c, bp.conn.FixedProb(1, seed=123), g_max=10, tau=5., + output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) + bpl.Exponential(b, d, bp.conn.FixedProb(1, seed=123), g_max=10, tau=5., + output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) + bpl.Exponential(d, a, bp.conn.FixedProb(1, seed=123), g_max=10, tau=5., + output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) + opt = Optimizer() + matrix = opt.get_edge_weight_matrix(ResManager.syns, total_pop_num=len(ResManager.pops)) + self.assertTrue(np.array_equal(matrix, [[0., 6., 8., 0.], + [0., 0., 12., 15.], + [0., 0., 0., 0.], + [10., 0., 0., 0.]])) diff --git a/tests/test_respa_base.py b/tests/test_respa_base.py index 4241c3a..889c129 100644 --- a/tests/test_respa_base.py +++ b/tests/test_respa_base.py @@ -3,9 +3,10 @@ import brainpy as bp from brainpy.dyn import channels, synouts import brainpy.math as bm +from .base import BaseTest -class BaseFunctionsTestCase(unittest.TestCase): +class BaseFunctionsTestCase(BaseTest): def testbasefunc(self): class MyNetwork(bpl.Network): def __init__(self, *ds_tuple): @@ -16,8 +17,8 @@ def __init__(self, *ds_tuple): tau_ref=5., method='exp_auto', V_initializer=bp.initialize.Normal(-55., 2.)) # self.c = bpl.Exponential(ds_tuple[0], self.a, bp.conn.FixedProb( # 0.02), g_max=10, tau=5., output=bp.synouts.COBA(E=0.), method='exp_auto') - self.d = bpl.Exponential(self.a[100:], self.b, bp.conn.FixedProb( - 0.02, seed=123), g_max=10, tau=5., output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) + self.d = bpl.Exponential(self.a, self.b, bp.conn.FixedProb( + 0.02, seed=123), g_max=10, tau=5., output=bp.synouts.COBA(E=0.), method='exp_auto', delay_step=1) net = MyNetwork() net.build() @@ -35,19 +36,19 @@ def __init__(self, *ds_tuple): inputs = bpl.input_transform([(net.a, 20)]) monitors = bpl.monitor_transform([net.b]) runner = bpl.DSRunner( - net, - monitors=monitors, - inputs=inputs, - jit=False + net, + monitors=monitors, + inputs=inputs, + jit=False ) - runner.run(10.) 
- if 'spike' in runner.mon: - bp.visualize.raster_plot( - runner.mon.ts, runner.mon['spike'], show=True) - print(net.pops_) - print(net.pops_by_rank) - print(net.syns_) - # print(net.nodes()) + runner.run(5.) + # if 'spike' in runner.mon: + # bp.visualize.raster_plot( + # runner.mon.ts, runner.mon['spike'], show=True) + # print(net.pops_) + # print(net.pops_by_rank) + # print(net.syns_) + # print(net.nodes()) def testBaseNeuronregister(self): @bpl.register() @@ -106,8 +107,8 @@ def run_ei_v1(): net = EINet_v1(scale=1) net.build() runner = bpl.DSRunner(net, monitors={'E.spike': net.E.spike}) - runner.run(100.) - bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) + runner.run(5.) + # bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) run_ei_v1()