Skip to content

Commit

Permalink
[FIX] Added a new example using memristive reservoirs and fixed bug w…
Browse files Browse the repository at this point in the history
…ith plotting module (#47)
  • Loading branch information
estefanysuarez authored May 17, 2024
1 parent 6c87525 commit cfb78b4
Show file tree
Hide file tree
Showing 3 changed files with 249 additions and 23 deletions.
30 changes: 13 additions & 17 deletions conn2res/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,25 +47,25 @@
""",
title="""\
title : str, optional
Title to be shown at the superior part of the figure.
Title to be shown at the superior part of the figure.\
""",
show="""\
show : bool, optional
If True, it will display the matplotlib.pyplot.figure object
If True, it will display the matplotlib.pyplot.figure object.\
""",
savefig="""\
savefig : bool, optional
If True, it will save the matploblib.pyplot.figure object as a '.png' file by default.
The format of the file can be changed using the 'savefig.format'
keyword in the rc_params argument.
The format of the file can be changed using the 'savefig.format'.
keyword in the rc_params argument.\
""",
fname="""\
fname : str or path-like
Path where the figure will be saved.
Path where the figure will be saved.\
""",
kwargs="""\
kwargs : key-value pairs
Other keyword arguments pass directly to the underlying seaborn plotting function.
Other keyword arguments pass directly to the underlying seaborn plotting function.\
"""
)

Expand Down Expand Up @@ -316,9 +316,7 @@ def plot_reservoir_states(
_description_, by default None
{rc_params}
{fig_params}
ax_params : list of dict
list of dictionaries with keyword arguments for `matplotlib.pyplot.axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html#matplotlib.axes.Axes`__.
Values to set axes's properties, by default [{}] * 2.
{ax_params}
{lg_params}
{title}
{show}
Expand Down Expand Up @@ -466,12 +464,8 @@ def plot_diagnostics(
_description_, by default None
{rc_params}
{fig_params}
ax_params : list of dict
list of dictionaries with keyword arguments for `matplotlib.pyplot.axes <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.html#matplotlib.axes.Axes`__.
Values to set axes's properties, by default [{}] * 3.
lg_params : list of dict
list of dictionaries with keyword arguments for `matplotlib.axes.Axes.legend <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.legend.html`__.
Values to set legend's properties, by default [{}] * 3.
{ax_params}
{lg_params}
{title}
{show}
{savefig}
Expand Down Expand Up @@ -757,7 +751,8 @@ def plot_phase_space(
# reset rc defaults
mpl.rcdefaults()

def plot_spike_raster(tspike, x1, x2, title = "Spike Raster"):

def plot_spike_raster(tspike, x1, x2, title="Spike Raster"):
"""
Plot a spike raster plot.
Expand Down Expand Up @@ -798,7 +793,8 @@ def plot_spike_raster(tspike, x1, x2, title = "Spike Raster"):
plt.grid(True, linestyle='--', alpha = 0.7)
plt.show()

def plot_membrane_voltages(membrane_voltages, x1, x2, neuron_idx = None,

def plot_membrane_voltages(membrane_voltages, x1, x2, neuron_idx=None,
dt = 0.05, title="Membrane Voltages"):
"""
Plot the membrane voltages of the neurons.
Expand Down
12 changes: 6 additions & 6 deletions conn2res/reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,7 +986,7 @@ def getV(self, Vi, Ve, Vgr=None):
else:
V[i, j] = nv_dict[j] - nv_dict[i]

return mask(self, V)
return self.mask(V)

def simulate(self, Vext, ic=None, mode='forward'):
"""
Expand All @@ -1013,8 +1013,8 @@ def simulate(self, Vext, ic=None, mode='forward'):
N: number of nodes in the network
"""

# print('\n GENERATING RESERVOIR STATES ...')
# print(f'\n SIMULATING STATES IN {mode.upper()} MODE ...')
print('\n GENERATING RESERVOIR STATES ...')
print(f'\n SIMULATING STATES IN {mode.upper()} MODE ...')

# initialize reservoir states
self._state = np.zeros((len(Vext), self._n_nodes))
Expand All @@ -1027,7 +1027,7 @@ def simulate(self, Vext, ic=None, mode='forward'):
for t, Ve in enumerate(Vext):
if mode == 'forward':

if (t > 0) and (t % 100 == 0):
if (t > 0) and (t % 2 == 0):
print(f'\t ----- timestep = {t}')

# get voltage at internal nodes
Expand Down Expand Up @@ -1306,8 +1306,8 @@ def dG(self, V, G=None, dt=1e-4, seed=None):
# use random number generator for reproducibility
rng = np.random.default_rng(seed=seed)

Gab = rng.binomial(Na.astype(int), mask(self, Pa))
Gba = rng.binomial(Nb.astype(int), mask(self, Pb))
Gab = rng.binomial(Na.astype(int), self.mask(Pa))
Gba = rng.binomial(Nb.astype(int), self.mask(Pb))

if utils.check_symmetric(self._W):
Gab = utils.make_symmetric(Gab)
Expand Down
230 changes: 230 additions & 0 deletions examples/example4_sims.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
# -*- coding: utf-8 -*-
"""
Connectome-informed reservoir - Memristive Network
=================================================
This example demonstrates how to use the conn2res toolbox
to perform a task using a human connectomed-informed
Memristive network
"""
import warnings
import os
import numpy as np
import pandas as pd
from conn2res.tasks import Conn2ResTask
from conn2res.connectivity import Conn
from conn2res.reservoir import MSSNetwork
from conn2res.readout import Readout
from conn2res import readout, plotting

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

# #####################################################################
# First, let's initialize some constant variables
# #####################################################################
# project and figure directory
PROJ_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
OUTPUT_DIR = os.path.join(PROJ_DIR, 'figs')
if not os.path.isdir(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)

# number of runs for each task
N_RUNS = 1

# name of the tasks to be performed
TASKS = [
'MemoryCapacity'
]

# define metrics to evaluate readout's model performance
METRICS = [
'corrcoef',
]

# define alpha values to vary global reservoir dynamics
ALPHAS = [1.0] # np.linspace(0, 2, 41)[1:]

for task_name in TASKS:

print(f'\n---------------TASK: {task_name.upper()}---------------')

OUTPUT_DIR = os.path.join(PROJ_DIR, 'figs', task_name)
if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)

# #####################################################################
# Second, let's create an instance of a NeuroGym task. To do so we need
# the name of task.
# #####################################################################
task = Conn2ResTask(name=task_name)

# #####################################################################
# Third, let's import the connectivity matrix we are going to use to
# define the connections of the reservoir. For this we will be using
# the human connectome parcellated into 1015 brain regions following
# the Desikan Killiany atlas (Desikan, et al., 2006).
# #####################################################################

# load connectivity data of one subject
conn = Conn(subj_id=0)

# scale conenctivity weights between [0,1] and normalize by spectral its
# radius
conn.scale_and_normalize()

# #####################################################################
# Next, we will simulate the dynamics of the reservoir. We will evaluate
# the effect of local network dynamics by using different activation
# functions. We will also evaluate network performance across dynamical
# regimes by parametrically tuning alpha, which corresponds to the
# spectral radius of the connectivity matrix (alpha parameter).
# #####################################################################
df_runs = []
for run in range(N_RUNS):
print(f'\n\t\t--- run = {run} ---')

# fetch data to perform task
x, y = task.fetch_data(n_trials=500, input_gain=1)

# visualize task dataset
if run == 0:
plotting.plot_iodata(
x, y, title=task.name, savefig=True,
fname=os.path.join(OUTPUT_DIR, f'io_{task.name}'),
rc_params={'figure.dpi': 300, 'savefig.dpi': 300},
show=False
)

# split data into training and test sets
x_train, x_test, y_train, y_test = readout.train_test_split(x, y)

# We will define the set of external, internal and ground nodes. We
# will also define the set of readout nodes, which will be the ones
# to be actually used to perform the task.
gr_nodes = conn.get_nodes(
'random',
nodes_from=conn.get_nodes('ctx'),
n_nodes=1
) # we select a single random node from ctx - GROUND

ext_nodes = conn.get_nodes(
'random',
nodes_from=conn.get_nodes('subctx'),
n_nodes=task.n_features
) # we select a random set of nodes from subctx - EXTERNAL/INPUT

int_nodes = conn.get_nodes(
'all',
nodes_without=np.union1d(gr_nodes, ext_nodes),
n_nodes=task.n_features
) # we select the reamining ctx and subctx - INTERNAL

output_nodes = conn.get_nodes(
'ctx',
nodes_without=gr_nodes,
n_nodes=task.n_features
) # we use the reamining ctx regions - READOUT/OUTPUT

# instantiate an Metastable Switch Memristive network object
mssn = MSSNetwork(
w=conn.w,
int_nodes=int_nodes,
ext_nodes=ext_nodes,
gr_nodes=gr_nodes,
mode='backward'
)

# instantiate a Readout object
readout_module = Readout(estimator=readout.select_model(y))

# defined performance metrics based on Readout's type of model
metrics = METRICS

# iterate global dynamics using different alpha values
df_alpha = []
for alpha in ALPHAS:

print(f'\n\t\t\t----- alpha = {alpha} -----')

# scale connectivity matrix by alpha
mssn.w = alpha * conn.w

# simulate reservoir states
rs_train = mssn.simulate(
Vext=x_train
)

rs_test = mssn.simulate(
Vext=x_test,
)

# visualize reservoir states
if run == 0 and alpha == 1.0:
plotting.plot_reservoir_states(
x=x_train, reservoir_states=rs_train,
title=task.name,
savefig=True,
fname=os.path.join(OUTPUT_DIR, f'res_states_train_{task.name}'),
rc_params={'figure.dpi': 300, 'savefig.dpi': 300},
show=False
)
plotting.plot_reservoir_states(
x=x_test, reservoir_states=rs_test,
title=task.name,
savefig=True,
fname=os.path.join(OUTPUT_DIR, f'res_states_test_{task.name}'),
rc_params={'figure.dpi': 300, 'savefig.dpi': 300},
show=False
)

# perform task
df_res = readout_module.run_task(
X=(rs_train, rs_test), y=(y_train, y_test),
sample_weight='both', metric=metrics,
readout_modules=None, readout_nodes=None,
)

# assign column with alpha value and append df_res
# to df_alpha
df_res['alpha'] = np.round(alpha, 3)
df_alpha.append(df_res)

# concatenate results across alpha values and append
# df_alpha to df_runs
df_alpha = pd.concat(df_alpha, ignore_index=True)
df_alpha['run'] = run
df_runs.append(df_alpha)

# concatenate results across runs and append
# df_runs to df_subj
df_runs = pd.concat(df_runs, ignore_index=True)
if 'module' in df_runs.columns:
df_subj = df_runs[
['module', 'n_nodes', 'run', 'alpha'] + metrics
]
else:
df_subj = df_runs[
['run', 'alpha'] + metrics
]

df_subj.to_csv(
os.path.join(OUTPUT_DIR, f'results_{task.name}.csv'),
index=False
)

###########################################################################
# visualize performance curve
df_subj = pd.read_csv(
os.path.join(OUTPUT_DIR, f'results_{task.name}.csv'),
index_col=False
)

for metric in metrics:
plotting.plot_performance(
df_subj, x='alpha', y=metric,
title=task.name, savefig=True,
fname=os.path.join(OUTPUT_DIR, f'perf_{task.name}_{metric}'),
rc_params={'figure.dpi': 300, 'savefig.dpi': 300},
show=False
)

0 comments on commit cfb78b4

Please sign in to comment.