Skip to content

Commit

Permalink
add test_optimizer
Browse files Browse the repository at this point in the history
fix lint, add flake8 config file, autopep8 config file
  • Loading branch information
dulingkang committed Nov 25, 2022
1 parent 9d3ce59 commit f92e3f2
Show file tree
Hide file tree
Showing 22 changed files with 381 additions and 361 deletions.
15 changes: 15 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[flake8]
select = B,C,E,F,P,T4,W,B9
indent-size = 2
max-line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
ignore =
E203,E305,E402,E501,E721,E741,F405,F821,F841,F999,W503,W504,C408,E302,W291,E303,E111,
EXE001,
B007,B008,
C400,C401,C402,C403,C404,C405,C407,C411,C413,C414,C415,
E731
optional-ascii-coding = True

exclude = */__init__.py
5 changes: 1 addition & 4 deletions .github/workflows/linux_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ jobs:
python setup.py install
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
flake8 . --show-source --statistics
- name: Test with pytest
run: |
pytest tests
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,4 @@ dmypy.json
# Pyre type checker
.pyre/
.vscode
.idea
35 changes: 18 additions & 17 deletions bpl/core/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from .base import RemoteDynamicalSystem
from mpi4py import MPI
import platform
if platform.system()!='Windows':

if platform.system() != 'Windows':
import mpi4jax
import numpy as np

Expand All @@ -15,19 +16,19 @@ class RemoteNetwork(dyn.Network, RemoteDynamicalSystem):
"""

def __init__(
self,
*ds_tuple,
comm=MPI.COMM_WORLD,
name: str = None,
mode: Mode = normal,
**ds_dict
self,
*ds_tuple,
comm=MPI.COMM_WORLD,
name: str = None,
mode: Mode = normal,
**ds_dict
):
super(RemoteNetwork, self).__init__(*ds_tuple,
name=name,
mode=mode,
**ds_dict)
name=name,
mode=mode,
**ds_dict)
self.comm = comm
if self.comm == None:
if self.comm is None:
self.rank = None
else:
self.rank = self.comm.Get_rank()
Expand All @@ -39,7 +40,7 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None):
if nodes is None:
nodes = tuple(self.nodes(level=1, include_self=False).subset(dyn.DynamicalSystem).unique().values())
elif isinstance(nodes, dyn.DynamicalSystem):
nodes = (nodes, )
nodes = (nodes,)
elif isinstance(nodes, dict):
nodes = tuple(nodes.values())
if not isinstance(nodes, (tuple, list)):
Expand All @@ -48,14 +49,14 @@ def update_local_delays(self, nodes: Union[Sequence, Dict] = None):
if hasattr(node, 'comm'):
for name in node.local_delay_vars:
if self.rank == node.source_rank:
if platform.system()=='Windows':
if platform.system() == 'Windows':
self.comm.send(len(node.pre.spike), dest=node.target_rank, tag=2)
self.comm.Send(node.pre.spike.to_numpy(), dest=node.target_rank, tag=3)
else:
token = mpi4jax.send(node.pre.spike.value, dest=node.target_rank, tag=3, comm=self.comm)
elif self.rank == node.target_rank:
delay = self.remote_global_delay_data[name][0]
if platform.system()=='Windows':
if platform.system() == 'Windows':
pre_len = self.comm.recv(source=node.source_rank, tag=2)
target = np.empty(pre_len, dtype=np.bool_)
self.comm.Recv(target, source=node.source_rank, tag=3)
Expand All @@ -81,14 +82,14 @@ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None):
if hasattr(node, 'comm'):
for name in node.local_delay_vars:
if self.rank == node.source_rank:
if platform.system()=='Windows':
if platform.system() == 'Windows':
self.comm.send(len(node.pre.spike), dest=node.target_rank, tag=4)
self.comm.Send(node.pre.spike.to_numpy(), dest=node.target_rank, tag=5)
else:
token = mpi4jax.send(node.pre.spike.value, dest=node.target_rank, tag=4, comm=self.comm)
elif self.rank == node.target_rank:
delay = self.remote_global_delay_data[name][0]
if platform.system()=='Windows':
if platform.system() == 'Windows':
pre_len = self.comm.recv(source=node.source_rank, tag=4)
target = np.empty(pre_len, dtype=np.bool_)
self.comm.Recv(target, source=node.source_rank, tag=5)
Expand All @@ -100,4 +101,4 @@ def reset_local_delays(self, nodes: Union[Sequence, Dict] = None):
for name in node.local_delay_vars:
delay = self.global_delay_data[name][0]
target = self.global_delay_data[name][1]
delay.reset(target.value)
delay.reset(target.value)
42 changes: 21 additions & 21 deletions bpl/core/neurons.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,31 @@ class ProxyLIF(ProxyNeuGroup):
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
self,
size: Shape,
keep_size: bool = False,

# other parameter
V_rest: Union[float, Array, Initializer, Callable] = 0.,
V_reset: Union[float, Array, Initializer, Callable] = -5.,
V_th: Union[float, Array, Initializer, Callable] = 20.,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Optional[Union[float, Array, Initializer, Callable]] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Optional[Union[float, Array, Initializer, Callable]] = None,
method: str = 'exp_auto',
name: Optional[str] = None,
# other parameter
V_rest: Union[float, Array, Initializer, Callable] = 0.,
V_reset: Union[float, Array, Initializer, Callable] = -5.,
V_th: Union[float, Array, Initializer, Callable] = 20.,
R: Union[float, Array, Initializer, Callable] = 1.,
tau: Union[float, Array, Initializer, Callable] = 10.,
tau_ref: Optional[Union[float, Array, Initializer, Callable]] = None,
V_initializer: Union[Initializer, Callable, Array] = ZeroInit(),
noise: Optional[Union[float, Array, Initializer, Callable]] = None,
method: str = 'exp_auto',
name: Optional[str] = None,

# training parameter
mode: Mode = normal,
spike_fun: Callable = bm.spike_with_sigmoid_grad,
# training parameter
mode: Mode = normal,
spike_fun: Callable = bm.spike_with_sigmoid_grad,
):
# initialization
super(ProxyLIF, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)
name=name,
keep_size=keep_size,
mode=mode)
check_mode(self.mode, (TrainingMode, NormalMode), self.__class__)

# parameters
Expand All @@ -62,7 +62,7 @@ def __init__(
# It is useful for JIT in multi-device enviornment.
sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool # the gradient of spike is a float
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)

if self.tau_ref is not None:
self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, 0, mode)
self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), 0, mode)
Expand Down
33 changes: 17 additions & 16 deletions bpl/core/neurons_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,31 +12,32 @@ class ProxyNeuGroup(dyn.NeuGroup):
"""

def __init__(
self,
size: Shape,
keep_size: bool = False,
name: str = None,
mode: Mode = normal,
self,
size: Shape,
keep_size: bool = False,
name: str = None,
mode: Mode = normal,
):
# initialize
super(ProxyNeuGroup, self).__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)
name=name,
keep_size=keep_size,
mode=mode)

def __getitem__(self, item):
return ProxyNeuGroupView(target=self, index=item, keep_size=self.keep_size)


class ProxyNeuGroupView(ProxyNeuGroup):
"""A view for a neuron group instance in multi-device enviornment."""

def __init__(
self,
target: ProxyNeuGroup,
index: Union[slice, Sequence, Array],
name: str = None,
mode: Mode = None,
keep_size: bool = False
self,
target: ProxyNeuGroup,
index: Union[slice, Sequence, Array],
name: str = None,
mode: Mode = None,
keep_size: bool = False
):
# check target
if not isinstance(target, dyn.DynamicalSystem):
Expand All @@ -46,7 +47,7 @@ def __init__(
if isinstance(index, (int, slice)):
index = (index,)
self.index = index # the slice

# check slicing
var_shapes = target.varshape
if len(self.index) > len(var_shapes):
Expand Down Expand Up @@ -78,4 +79,4 @@ def __init__(
sp_type = bm.dftype() if isinstance(mode, TrainingMode) else bool # the gradient of spike is a float
self.spike = variable_(lambda s: bm.zeros(s, dtype=sp_type), self.varshape, mode)
# Proxy neuron needs 'V' attribute during distributed network updating.
self.V = variable_(bm.zeros, 0, mode)
self.V = variable_(bm.zeros, 0, mode)
31 changes: 14 additions & 17 deletions bpl/core/runner.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import time
from typing import Dict, Union, Sequence, Callable

import numpy as np
import tqdm.auto
from jax.tree_util import tree_map
from jax.experimental.host_callback import id_tap

Expand All @@ -14,26 +10,27 @@

class BplRunner(dyn.DSRunner):
def __init__(
self,
target: dyn.DynamicalSystem,

# inputs for target variables
inputs: Sequence = (),
fun_inputs: Callable = None,

# extra info
dt: float = None,
t0: Union[float, int] = 0.,
callback: Callable = None,
**kwargs
self,
target: dyn.DynamicalSystem,

# inputs for target variables
inputs: Sequence = (),
fun_inputs: Callable = None,

# extra info
dt: float = None,
t0: Union[float, int] = 0.,
callback: Callable = None,
**kwargs
):
super(BplRunner, self).__init__(
target=target, inputs=inputs, fun_inputs=fun_inputs,
dt=dt, t0=t0, **kwargs)
self.callback = callback

def f_predict(self, shared_args: Dict = None):
if shared_args is None: shared_args = dict()
if shared_args is None:
shared_args = dict()

shared_kwargs_str = serialize_kwargs(shared_args)
if shared_kwargs_str not in self._f_predict_compiled:
Expand Down
Loading

0 comments on commit f92e3f2

Please sign in to comment.