Skip to content

Commit

Permalink
update metrics and raster plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
zeke committed Aug 5, 2019
1 parent 5717df4 commit 651052b
Show file tree
Hide file tree
Showing 7 changed files with 393 additions and 76 deletions.
34 changes: 34 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
backcall==0.1.0
bleach==1.5.0
certifi==2018.10.15
decorator==4.3.0
html5lib==0.9999999
ipykernel==5.1.0
ipython==7.2.0
ipython-genutils==0.2.0
jedi==0.13.2
jupyter-client==5.2.4
jupyter-core==4.4.0
Markdown==2.6.11
mkl-fft==1.0.10
mkl-random==1.0.2
numpy==1.15.4
pandas==0.23.4
parso==0.3.1
pexpect==4.6.0
pickleshare==0.7.5
prompt-toolkit==2.0.7
protobuf==3.6.1
ptyprocess==0.6.0
Pygments==2.3.1
python-dateutil==2.7.5
pytz==2018.7
pyzmq==17.1.2
six==1.12.0
tensorflow==1.4.1
tensorflow-tensorboard==1.5.1
tornado==5.1.1
traitlets==4.3.2
wcwidth==0.1.7
webencodings==0.5.1
Werkzeug==0.14.1
36 changes: 32 additions & 4 deletions swissknife/bci/core/basic_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import scipy as sp
import math
from numba import jit

# fucntions for handling and plotting
def decim(x, q):
Expand Down Expand Up @@ -74,10 +75,13 @@ def plot_raster(x, t1=0, t2=-1, t0=0, ax=None, bin_size=0):
# if bin_size was entered, we want a psth
if bin_size > 0:
psth, t_dec = make_psth(x, t1=t1, t2=t2, t0=t0, bin_size=bin_size)
raster = ax.plot(t_dec, psth)
raster = ax.plot(t_dec, psth, color='C5')
ax.set_ylim(0, max(psth) * 1.2)
stim = ax.plot((0, 0), (0, max(psth) * 1.2), 'k--')
#stim = ax.plot((0, 0), (0, max(psth) * 1.2), 'k--')
#ax.axvline(x=0, color='C6', linestyle=':')
t_max = max(t_dec)
ax.set_ylabel('F.R. (Hz)')
ax.yaxis.set_ticks([int(max(psth)*0.8)])

else:
# Chop the segment
Expand All @@ -101,10 +105,13 @@ def plot_raster(x, t1=0, t2=-1, t0=0, ax=None, bin_size=0):
col = np.arange(events, dtype=np.float)
frame = col[:, np.newaxis] + row[np.newaxis, :]

raster = ax.scatter(t * x, frame * x, marker='|', rasterized=True)
raster = ax.scatter(t * x, frame * x, marker='|', linewidth=0.2,
rasterized=True, color='C3')
ax.set_ylim(0, events + 1)
ax.plot((0, 0), (0, events + 1), 'k--')
#ax.plot((0, 0), (0, events + 1), 'k--')
t_max = t_stamps - t0
ax.set_ylabel('trial')
ax.yaxis.set_ticks([events - 1])

ax.set_xlim(0 - t0, t_max)
return raster, ax
Expand Down Expand Up @@ -164,3 +171,24 @@ def sparse_raster(x, nan=False):
if not nan:
raster[np.isnan(raster)] = 0
return raster

@jit(nopython=True)
def plottable_array(x:np.ndarray, scale:np.ndarray, offset:np.ndarray) -> np.ndarray:
""" Rescale and offset an array for quick plotting multiple channels, along the
1 axis, for each jth axis
Arguments:
x {np.ndarray} -- [n_col x n_row] array (each col is a chan, for instance)
scale {np.ndarray} -- [n_col] vector of scales (typically the ptp values of each row)
offset {np.ndarray} -- [n_col] vector offsets (typycally range (row))
Returns:
np.ndarray -- [n_row x n_col] scaled, offsetted array to plot
"""
# for each row [i]:
# - divide by scale_i
# - add offset_i
n_row, n_col = x.shape
for col in range(n_col):
for row in range(n_row):
x[row, col] = x[row, col] * scale[col] + offset[col]
return x
77 changes: 77 additions & 0 deletions swissknife/bci/core/file/h5_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import h5py
import logging
import os
from tqdm import tqdm_notebook as tqdm

#from mdaio import writemda16i

Expand All @@ -27,6 +28,34 @@ def file_checker(h5_file, *args, **kwargs):

return file_checker

def h5_decorator(default_mode='r'):
"""
Decorator to open h5 structure if the path was provided to a function.
:param h5_function: a function that receives an h5file as first argument
:param default_mode: what mode to open the file by default.
It is overriden when file is entered open and when option 'mode' is set
in h5_function (if it exists)
:return: decorated function that takes open or path as first argument
"""
def wrap(h5_function):
def file_checker(h5_file, *args, **kwargs):
if 'mode' in kwargs.keys():
mode = kwargs['mode']
else:
mode = default_mode
#logger.debug('mode {}'.format(mode))
try:
if type(h5_file) is not h5py._hl.files.File:
with h5py.File(h5_file, mode) as h5_file:
return_value = h5_function(h5_file, *args, **kwargs)
else:
return_value = h5_function(h5_file, *args, **kwargs)
return return_value
except UnboundLocalError as err:
logger.error(err)
raise
return file_checker
return wrap

def list_subgroups(h5_group):
return [key for key, val in h5_group.items() if isinstance(val, h5py.Group)]
Expand Down Expand Up @@ -345,3 +374,51 @@ def list_event_types(kwe_file):

def count_events(kwe_file, ev_type, ev_name, rec=None):
return get_events_one_type(kwe_file, ev_type, ev_name, rec=rec).size

@h5_decorator(default_mode='r')
def collect_frames_fast(kwd_file, recs_list, starts, span, chan_list):
recs = np.unique(recs_list)
all_frames_list = []
for i_rec, rec in tqdm(enumerate(recs), total=recs.size):
starts_from_rec = starts[recs_list == rec]
dset = get_data_set(kwd_file, rec)
n_samples = dset.shape[0]
valid_starts = starts_from_rec[(starts_from_rec > 0)
& (starts_from_rec + span < n_samples)]
if valid_starts.size < starts_from_rec.size:
logger.warn('Some frames were out of bounds and will be discarded')
logger.warn('will collect only {0} events...'.format(
valid_starts.size))

# get the dataset slices for only the channel list
this_rec_frames = get_slice_array(dset, valid_starts, span, chan_list)
all_frames_list.append(this_rec_frames)

try:
all_frames_array = np.concatenate(all_frames_list, axis=0)
except ValueError:
raise
# logger.warn('Failed to collect stream frames, return is nan array')
# zero_dset_shape = get_data_set(kwd_file, rec).shape
# all_frames_array = np.empty([1, *zero_dset_shape])
# all_frames_array[:] = np.nan
return all_frames_array

def get_slice_array(dset: np.ndarray, starts: np.ndarray, span: np.int, chan_list) -> np.ndarray:
n_slices = starts.size
#n_chan = dset.shape[1]
n_chan = chan_list.size
chan_list = np.array(chan_list)
#logger.info('nslice {}, span {}, chan {}'.format(n_slices,n_chan,chan_list))
slices_array = np.zeros([n_slices, span, n_chan])
#logger.info('dset {}, starts {}, span {}, chan_list {}'.format(dset.shape, starts, span, chan_list))
#logger.info('starts {}'.format(starts.dtype))
for i, start in enumerate(starts.astype(np.int64)):
#logger.info('start {}'.format(start.shape))
#logger.info('chan_list {}'.format(chan_list))
#logger.info('i {}'.format(i))
#logger.info('span'.format(span))
#start = np.int(start)
#aux = dset[5: 5 + span, chan_list]
slices_array[i, :, :] = dset[start: start + span, chan_list]
return slices_array
18 changes: 18 additions & 0 deletions swissknife/bci/core/kwik_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,24 @@ def rec_start_array(kwik):
# start_array[i_rec] = get_rec_start_sample(kwik, rec_id)
return np.array(start_array)

@h5f.h5_wrap
def get_rec_sizes(kwd_file):
rec_list = get_rec_list(kwd_file)
rec_sizes = {rec: get_data_size(kwd_file, rec)
for rec in rec_list}
return rec_sizes

@h5f.h5_wrap
def get_rec_starts(kwd_file):
logger.debug('Getting rec_starts')
rec_sizes = get_rec_sizes(kwd_file)
#logger.debug('rec sizes {}'.format(rec_sizes))
starts_vec = np.array(list(rec_sizes.values())).cumsum()
#logger.debug('starts vector {}'.format(starts_vec))
starts_vec = np.hstack([0, starts_vec[:-1]])
rec_starts = {rec: r_start for r_start,
rec in zip(starts_vec, rec_sizes.keys())}
return rec_starts

@h5f.h5_wrap
def get_corresponding_rec(kwik, stamps):
Expand Down
Loading

0 comments on commit 651052b

Please sign in to comment.