Skip to content

Commit

Permalink
Merge pull request #372 from AllenInstitute/update/workshop-2024
Browse files Browse the repository at this point in the history
Update/workshop 2024
  • Loading branch information
kaeldai authored Jun 14, 2024
2 parents b3662cb + 862fd7f commit def7f3e
Show file tree
Hide file tree
Showing 29 changed files with 520 additions and 96 deletions.
50 changes: 38 additions & 12 deletions bmtk/analyzer/ecp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
import h5py
import matplotlib.pyplot as plt
import numpy as np
from decimal import Decimal

from bmtk.utils.sonata.config import SonataConfig
from bmtk.simulator.utils import simulation_reports
# from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm


def _get_ecp_path(ecp_path=None, config=None, report_name=None):
Expand Down Expand Up @@ -55,30 +57,54 @@ def plot_ecp(config_file=None, report_name=None, ecp_path=None, title=None, show
channels = ecp_h5['/ecp/channel_id'][()]
fig, axes = plt.subplots(len(channels), 1)
fig.text(0.04, 0.5, 'channel id', va='center', rotation='vertical')
v_min, v_max = ecp_h5['/ecp/data'][()].min(), ecp_h5['/ecp/data'][()].max()
# print(v_max - v_min)
# exit()

for idx, channel in enumerate(channels):
data = ecp_h5['/ecp/data'][:, idx]
# print(channel, np.min(data), np.max(data))
axes[idx].plot(time_traces, data)
axes[idx].spines["top"].set_visible(False)
axes[idx].spines["right"].set_visible(False)
axes[idx].set_yticks([])
axes[idx].set_ylabel(channel)
axes[idx].set_ylim([v_min, v_max])

if idx+1 != len(channels):
axes[idx].spines["bottom"].set_visible(False)
axes[idx].set_xticks([])
else:
axes[idx].set_xlabel('timestamps (ms)')
# scalebar = AnchoredSizeBar(axes[idx].transData,
# 2.0, '1 mV', 1,
# pad=0,
# borderpad=0,
# # color='b',
# frameon=True,
# # size_vertical=1.001,
# # fontproperties=fontprops
# )
#
# axes[idx].add_artist(scalebar)


if idx == 0:
scale_bar_size = (v_max-v_min)/2.0
scale_bar_label = f'{scale_bar_size:.2E}'
# print(scale_bar_label)
# exit()
fontprops = fm.FontProperties(size='x-small')

scalebar = AnchoredSizeBar(
axes[idx].transData,
size=scale_bar_size,
label=scale_bar_label,
loc='upper right',
pad=0.1,
borderpad=0.5,
sep=5,
# color='b',
frameon=False,
size_vertical=scale_bar_size,
# size_vertical=1.001,
fontproperties=fontprops
)
axes[idx].add_artist(scalebar)

# label = scalebar.txt_label
# label.set_rotation(270.0)
# label.set_verticalalignment('bottom')
# label.set_horizontalalignment('left')

if title:
fig.set_title(title)
Expand Down
3 changes: 2 additions & 1 deletion bmtk/simulator/bionet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
from bmtk.simulator.bionet.pyfunction_cache import synapse_model, synaptic_weight, cell_model, add_weight_function, model_processing
from bmtk.simulator.bionet.pyfunction_cache import synapse_model, synaptic_weight, cell_model, add_weight_function, model_processing, \
spikes_generator
from bmtk.simulator.bionet.config import Config
from bmtk.simulator.bionet.bionetwork import BioNetwork
from bmtk.simulator.bionet.biosimulator import BioSimulator
Expand Down
45 changes: 37 additions & 8 deletions bmtk/simulator/bionet/biocell.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from bmtk.simulator.bionet.morphology import Morphology
import six

import neuron
from neuron import h

pc = h.ParallelContext() # object to access MPI methods
Expand Down Expand Up @@ -74,9 +75,6 @@ class BioCell(Cell):
def __init__(self, node, population_name, bionetwork):
super(BioCell, self).__init__(node=node, population_name=population_name, network=bionetwork)

# Set up netcon object that can be used to detect and communicate cell spikes.
self.set_spike_detector(bionetwork.spike_threshold)

# Determine number of segments and store a list of all sections.
self._secs = []
self._secs_by_id = []
Expand Down Expand Up @@ -105,6 +103,10 @@ def __init__(self, node, population_name, bionetwork):
self._seg_coords = None
self.build_morphology()

# Set up netcon object that can be used to detect and communicate cell spikes.
self.set_spike_detector(bionetwork.spike_threshold)


def build_morphology(self):
morph_base = Morphology.load(hobj=self.hobj, morphology_file=self.morphology_file, cache_seg_props=True)

Expand All @@ -126,6 +128,10 @@ def morphology(self):
"""The actual Morphology object instanstiation"""
return self._morphology

@property
def soma(self):
return self.morphology.soma

@property
def seg_coords(self):
"""Coordinates for segments/sections of the morphology, need to make public for ecp, xstim, and other
Expand All @@ -144,7 +150,7 @@ def seg_coords(self):
return self.morphology.seg_coords

def set_spike_detector(self, spike_threshold):
nc = h.NetCon(self.hobj.soma[0](0.5)._ref_v, None, sec=self.hobj.soma[0]) # attach spike detector to cell
nc = h.NetCon(self.soma[0](0.5)._ref_v, None, sec=self.soma[0])
nc.threshold = spike_threshold
pc.cell(self.gid, nc) # associate gid with spike detector

Expand Down Expand Up @@ -437,18 +443,41 @@ def __init__(self, node, population_name, bionetwork):
self._vecstim = h.VecStim()
self._vecstim.play(self._spike_trains)

self._precell_filter = bionetwork.spont_syns_filter
self._precell_filter = bionetwork.spont_syns_filter_pre
self._postcell_filter = bionetwork.spont_syns_filter_post
assert(isinstance(self._precell_filter, dict))

def _matches_filter(self, src_node):
def _matches_filter(self, src_node, trg_node=None):
"""Check to see if the presynaptic cell matches the criteria specified"""
for k, v in self._precell_filter.items():
# Some key may not show up as node_variable
if k == 'population' and k not in src_node:
key_val = src_node.population_name
else:
key_val = src_node[k]

if isinstance(v, (list, tuple)):
if key_val not in v:
return False
else:
if key_val != v:
return False

trg_node = trg_node or self
for k, v in self._postcell_filter.items():
# Some key may not show up as node_variable
if k == 'population' and k not in trg_node:
key_val = trg_node._node.population_name
else:
key_val = trg_node[k]

if isinstance(v, (list, tuple)):
if src_node[k] not in v:
if key_val not in v:
return False
else:
if src_node[k] != v:
if key_val != v:
return False

return True

def _set_connections(self, edge_prop, src_node, syn_weight, stim=None):
Expand Down
18 changes: 10 additions & 8 deletions bmtk/simulator/bionet/bionetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def __init__(self):
self._gid_pool = GidPool()

self.has_spont_syns = False
self.spont_syns_filter = None
self.spont_syns_filter_pre = None
self.spont_syns_filter_post = None
self.spont_syns_times = None

@property
Expand All @@ -88,7 +89,7 @@ def gid_pool(self):
def py_function_caches(self):
return nrn

def set_spont_syn_activity(self, precell_filter, timestamps):
def set_spont_syn_activity(self, precell_filter, postcell_filter, timestamps):
self._model_type_map = {
'biophysical': BioCellSpontSyn,
'point_process': PointProcessCellSpontSyns,
Expand All @@ -98,7 +99,8 @@ def set_spont_syn_activity(self, precell_filter, timestamps):
}

self.has_spont_syns = True
self.spont_syns_filter = precell_filter
self.spont_syns_filter_pre = precell_filter
self.spont_syns_filter_post = postcell_filter
self.spont_syns_times = timestamps

def get_node_id(self, population, node_id):
Expand Down Expand Up @@ -134,12 +136,12 @@ def add_nodes(self, node_population):
self._gid_pool.add_pool(node_population.name, node_population.n_nodes())
super(BioNetwork, self).add_nodes(node_population)

def get_virtual_cells(self, population, node_id, spike_trains):
def get_virtual_cells(self, population, node_id, spike_trains, spikes_generator=None, sim=None):
if node_id in self._virtual_nodes[population]:
return self._virtual_nodes[population][node_id]
else:
node = self.get_node_id(population, node_id)
virt_cell = VirtualCell(node, population, spike_trains)
virt_cell = VirtualCell(node, population, spike_trains, spikes_generator, sim)
self._virtual_nodes[population][node_id] = virt_cell
return virt_cell

Expand All @@ -151,7 +153,7 @@ def get_disconnected_cell(self, population, node_id, spike_trains):
virt_cell = self._disconnected_source_cells[population][node_id]
else:
node = self.get_node_id(population, node_id)
virt_cell = VirtualCell(node, population, spike_trains)
virt_cell = VirtualCell(node, population, spike_trains, self)
self._disconnected_source_cells[population][node_id] = virt_cell

return virt_cell
Expand Down Expand Up @@ -369,7 +371,7 @@ def find_edges(self, source_nodes=None, target_nodes=None):

return selected_edges

def add_spike_trains(self, spike_trains, node_set):
def add_spike_trains(self, spike_trains, node_set, spikes_generator=None, sim=None):
self._init_connections()

src_nodes = [node_pop for node_pop in self.node_populations if node_pop.name in node_set.population_names()]
Expand All @@ -379,7 +381,7 @@ def add_spike_trains(self, spike_trains, node_set):
if edge_pop.virtual_connections:
for trg_nid, trg_cell in self._rank_node_ids[edge_pop.target_nodes].items():
for edge in edge_pop.get_target(trg_nid):
src_cell = self.get_virtual_cells(source_population, edge.source_node_id, spike_trains)
src_cell = self.get_virtual_cells(source_population, edge.source_node_id, spike_trains, spikes_generator, sim)
trg_cell.set_syn_connection(edge, src_cell, src_cell)

elif edge_pop.mixed_connections:
Expand Down
22 changes: 20 additions & 2 deletions bmtk/simulator/bionet/biosimulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ def from_config(cls, config, network, set_recordings=True):
if sim_input.input_type == 'syn_activity':
network.set_spont_syn_activity(
precell_filter=sim_input.params['precell_filter'],
postcell_filter=sim_input.params.get('postcell_filter', {}),
timestamps=sim_input.params['timestamps']
)

Expand All @@ -346,13 +347,30 @@ def from_config(cls, config, network, set_recordings=True):

# TODO: Need to create a gid selector
for sim_input in inputs.from_config(config):
if sim_input.input_type == 'spikes' and sim_input.module in ['nwb', 'csv', 'sonata']:
if sim_input.input_type == 'spikes' and sim_input.module in ['nwb', 'csv', 'sonata', 'h5']:
io.log_info('Building virtual cell stimulations for {}'.format(sim_input.name))
path = sim_input.params['input_file']
spikes = SpikeTrains.load(path=path, file_type=sim_input.module, **sim_input.params)
# node_set_opts = sim_input.params.get('node_set', 'all')
node_set = network.get_node_set(sim_input.node_set)
network.add_spike_trains(spikes, node_set)
network.add_spike_trains(
spike_trains=spikes,
node_set=node_set,
spikes_generator=None,
sim=sim
)

elif sim_input.input_type == 'spikes' and sim_input.module == 'function':
io.log_info('Building virtual cell stimulations for {}'.format(sim_input.name))
# path = sim_input.params.get['input_file']
spikes_generator = sim_input.params['spikes_function']
node_set = network.get_node_set(sim_input.node_set)
network.add_spike_trains(
spike_trains=None,
node_set=node_set,
spikes_generator=spikes_generator,
sim=sim
)

elif sim_input.module == 'IClamp':
sim.add_mod(mods.IClampMod(input_type=sim_input.input_type, **sim_input.params))
Expand Down
6 changes: 6 additions & 0 deletions bmtk/simulator/bionet/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ def get_connection_info(self):

def set_syn_connections(self, edge_prop, src_node, stim=None):
raise NotImplementedError

def get_section(self, sec_name, sec_index):
raise NotImplementedError

def __contains__(self, node_prop):
return node_prop in self._node

def __getitem__(self, node_prop):
return self._node[node_prop]
12 changes: 11 additions & 1 deletion bmtk/simulator/bionet/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,12 @@ def create_output_dir(self):
io.setup_output_dir(self.output_dir, self.log_file)

def load_nrn_modules(self):
nrn.load_neuron_modules(self.mechanisms_dir, self.templates_dir)
nrn.load_neuron_modules(
mechanisms_dir=self.mechanisms_dir,
templates_dir=self.templates_dir,
default_templates=self.use_default_templates,
use_old_import3d=self.use_old_import3d
)

def build_env(self):
self.io = io
Expand All @@ -52,3 +57,8 @@ def build_env(self):

pc.barrier()
self.load_nrn_modules()

def _set_class_props(self):
super(Config, self)._set_class_props()
self.use_old_import3d = self.run.get('use_old_import3d', False)
self.use_default_templates = self.run.get('use_old_import3d', True)
52 changes: 40 additions & 12 deletions bmtk/simulator/bionet/default_setters/cell_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import numpy as np
from neuron import h
import inspect
try:
from sklearn.decomposition import PCA
except Exception as e:
Expand All @@ -41,19 +42,46 @@
"""

def loadHOC(cell, template_name, dynamics_params):
# Get template to instantiate
template_call = getattr(h, template_name)
if dynamics_params is not None and 'params' in dynamics_params:
template_params = dynamics_params['params']
if isinstance(template_params, list):
# pass in a list of parameters
hobj = template_call(*template_params)
"""A Generic function for creating a cell object from a NEURON HOC Template (eg. a *.hoc file with
`begintemplate template_name` in header). It essentially tries to guess the correct parameters that need to be
called so may not work the majority of the times and require to be overloaded.
:param cell: A SONATA node object, can be used as a dict to get individual properties of current cell.
:param template_name: name of HOCTemplate as stored in "model_template" attribute (hoc:<template_name>).
:param dynamics_params: Dictionary containing contents of cell['dynamics_params'] as loaded from a json file or hdf5.
If cell does not have "dynamics_params" attributes then will be set to None.
"""
try:
# Get template to instantiate
template_call = getattr(h, template_name)
except AttributeError as ae:
io.log_error(
f'loadHOC was unable to load in Neuron HOC Template "{template_name}, '
'Make sure appropiate .hoc file is stored in templates_dir.'
)
raise ae

try:
if dynamics_params is not None and 'params' in dynamics_params:
template_params = dynamics_params['params']
if isinstance(template_params, list):
# pass in a list of parameters
hobj = template_call(*template_params)
else:
# only a single parameter
hobj = template_call(template_params)
elif cell.morphology_file is not None:
# instantiate template with no parameters
hobj = template_call(cell.morphology_file)
else:
# only a single parameter
hobj = template_call(template_params)
else:
# instantiate template with no parameters
hobj = template_call()
hobj = template_call()
except RuntimeError as rte:
io.log_error(
f'bmtk.simualtor.bionet.default_setters.cell_models.loadHOC function failed to load HOC template "{template_call}". '
'If Hoc Templates requires special call to be initialized consider using `bmtk.simulator.bionet.add_cell_model()` '
'to overwrite this function.'
)
raise rte

# TODO: All "all" section if it doesn't exist
# hobj.all = h.SectionList()
Expand Down
Loading

0 comments on commit def7f3e

Please sign in to comment.