diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5a5a22a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,34 @@ +backcall==0.1.0 +bleach==1.5.0 +certifi==2018.10.15 +decorator==4.3.0 +html5lib==0.9999999 +ipykernel==5.1.0 +ipython==7.2.0 +ipython-genutils==0.2.0 +jedi==0.13.2 +jupyter-client==5.2.4 +jupyter-core==4.4.0 +Markdown==2.6.11 +mkl-fft==1.0.10 +mkl-random==1.0.2 +numpy==1.15.4 +pandas==0.23.4 +parso==0.3.1 +pexpect==4.6.0 +pickleshare==0.7.5 +prompt-toolkit==2.0.7 +protobuf==3.6.1 +ptyprocess==0.6.0 +Pygments==2.3.1 +python-dateutil==2.7.5 +pytz==2018.7 +pyzmq==17.1.2 +six==1.12.0 +tensorflow==1.4.1 +tensorflow-tensorboard==1.5.1 +tornado==5.1.1 +traitlets==4.3.2 +wcwidth==0.1.7 +webencodings==0.5.1 +Werkzeug==0.14.1 diff --git a/swissknife/bci/core/basic_plot.py b/swissknife/bci/core/basic_plot.py index 488a380..85aee4f 100755 --- a/swissknife/bci/core/basic_plot.py +++ b/swissknife/bci/core/basic_plot.py @@ -3,6 +3,7 @@ import numpy as np import scipy as sp import math +from numba import jit # fucntions for handling and plotting def decim(x, q): @@ -74,10 +75,13 @@ def plot_raster(x, t1=0, t2=-1, t0=0, ax=None, bin_size=0): # if bin_size was entered, we want a psth if bin_size > 0: psth, t_dec = make_psth(x, t1=t1, t2=t2, t0=t0, bin_size=bin_size) - raster = ax.plot(t_dec, psth) + raster = ax.plot(t_dec, psth, color='C5') ax.set_ylim(0, max(psth) * 1.2) - stim = ax.plot((0, 0), (0, max(psth) * 1.2), 'k--') + #stim = ax.plot((0, 0), (0, max(psth) * 1.2), 'k--') + #ax.axvline(x=0, color='C6', linestyle=':') t_max = max(t_dec) + ax.set_ylabel('F.R. (Hz)') + ax.yaxis.set_ticks([int(max(psth)*0.8)]) else: # Chop the segment @@ -101,10 +105,13 @@ def plot_raster(x, t1=0, t2=-1, t0=0, ax=None, bin_size=0): col = np.arange(events, dtype=np.float) frame = col[:, np.newaxis] + row[np.newaxis, :] - raster = ax.scatter(t * x, frame * x, marker='|', rasterized=True) + raster = ax.scatter(t * x, frame * x, marker='|', linewidth=0.2, + rasterized=True, color='C3') ax.set_ylim(0, events + 1) - ax.plot((0, 0), (0, events + 1), 'k--') + #ax.plot((0, 0), (0, events + 1), 'k--') t_max = t_stamps - t0 + ax.set_ylabel('trial') + ax.yaxis.set_ticks([events - 1]) ax.set_xlim(0 - t0, t_max) return raster, ax @@ -164,3 +171,24 @@ def sparse_raster(x, nan=False): if not nan: raster[np.isnan(raster)] = 0 return raster + +@jit(nopython=True) +def plottable_array(x:np.ndarray, scale:np.ndarray, offset:np.ndarray) -> np.ndarray: + """ Rescale and offset an array for quick plotting multiple channels, along the + 1 axis, for each jth axis + Arguments: + x {np.ndarray} -- [n_col x n_row] array (each col is a chan, for instance) + scale {np.ndarray} -- [n_col] vector of scales (typically the ptp values of each row) + offset {np.ndarray} -- [n_col] vector offsets (typycally range (row)) + + Returns: + np.ndarray -- [n_row x n_col] scaled, offsetted array to plot + """ + # for each row [i]: + # - divide by scale_i + # - add offset_i + n_row, n_col = x.shape + for col in range(n_col): + for row in range(n_row): + x[row, col] = x[row, col] * scale[col] + offset[col] + return x \ No newline at end of file diff --git a/swissknife/bci/core/file/h5_functions.py b/swissknife/bci/core/file/h5_functions.py index 9b10469..3400058 100755 --- a/swissknife/bci/core/file/h5_functions.py +++ b/swissknife/bci/core/file/h5_functions.py @@ -5,6 +5,7 @@ import h5py import logging import os +from tqdm import tqdm_notebook as tqdm #from mdaio import writemda16i @@ -27,6 +28,34 @@ def file_checker(h5_file, *args, **kwargs): return file_checker +def h5_decorator(default_mode='r'): + """ + Decorator to open h5 structure if the path was provided to a function. + :param h5_function: a function that receives an h5file as first argument + :param default_mode: what mode to open the file by default. + It is overriden when file is entered open and when option 'mode' is set + in h5_function (if it exists) + :return: decorated function that takes open or path as first argument + """ + def wrap(h5_function): + def file_checker(h5_file, *args, **kwargs): + if 'mode' in kwargs.keys(): + mode = kwargs['mode'] + else: + mode = default_mode + #logger.debug('mode {}'.format(mode)) + try: + if type(h5_file) is not h5py._hl.files.File: + with h5py.File(h5_file, mode) as h5_file: + return_value = h5_function(h5_file, *args, **kwargs) + else: + return_value = h5_function(h5_file, *args, **kwargs) + return return_value + except UnboundLocalError as err: + logger.error(err) + raise + return file_checker + return wrap def list_subgroups(h5_group): return [key for key, val in h5_group.items() if isinstance(val, h5py.Group)] @@ -345,3 +374,51 @@ def list_event_types(kwe_file): def count_events(kwe_file, ev_type, ev_name, rec=None): return get_events_one_type(kwe_file, ev_type, ev_name, rec=rec).size + +@h5_decorator(default_mode='r') +def collect_frames_fast(kwd_file, recs_list, starts, span, chan_list): + recs = np.unique(recs_list) + all_frames_list = [] + for i_rec, rec in tqdm(enumerate(recs), total=recs.size): + starts_from_rec = starts[recs_list == rec] + dset = get_data_set(kwd_file, rec) + n_samples = dset.shape[0] + valid_starts = starts_from_rec[(starts_from_rec > 0) + & (starts_from_rec + span < n_samples)] + if valid_starts.size < starts_from_rec.size: + logger.warn('Some frames were out of bounds and will be discarded') + logger.warn('will collect only {0} events...'.format( + valid_starts.size)) + + # get the dataset slices for only the channel list + this_rec_frames = get_slice_array(dset, valid_starts, span, chan_list) + all_frames_list.append(this_rec_frames) + + try: + all_frames_array = np.concatenate(all_frames_list, axis=0) + except ValueError: + raise + # logger.warn('Failed to collect stream frames, return is nan array') + # zero_dset_shape = get_data_set(kwd_file, rec).shape + # all_frames_array = np.empty([1, *zero_dset_shape]) + # all_frames_array[:] = np.nan + return all_frames_array + +def get_slice_array(dset: np.ndarray, starts: np.ndarray, span: np.int, chan_list) -> np.ndarray: + n_slices = starts.size + #n_chan = dset.shape[1] + n_chan = chan_list.size + chan_list = np.array(chan_list) + #logger.info('nslice {}, span {}, chan {}'.format(n_slices,n_chan,chan_list)) + slices_array = np.zeros([n_slices, span, n_chan]) + #logger.info('dset {}, starts {}, span {}, chan_list {}'.format(dset.shape, starts, span, chan_list)) + #logger.info('starts {}'.format(starts.dtype)) + for i, start in enumerate(starts.astype(np.int64)): + #logger.info('start {}'.format(start.shape)) + #logger.info('chan_list {}'.format(chan_list)) + #logger.info('i {}'.format(i)) + #logger.info('span'.format(span)) + #start = np.int(start) + #aux = dset[5: 5 + span, chan_list] + slices_array[i, :, :] = dset[start: start + span, chan_list] + return slices_array \ No newline at end of file diff --git a/swissknife/bci/core/kwik_functions.py b/swissknife/bci/core/kwik_functions.py index ef6fbd6..fe004e7 100755 --- a/swissknife/bci/core/kwik_functions.py +++ b/swissknife/bci/core/kwik_functions.py @@ -169,6 +169,24 @@ def rec_start_array(kwik): # start_array[i_rec] = get_rec_start_sample(kwik, rec_id) return np.array(start_array) +@h5f.h5_wrap +def get_rec_sizes(kwd_file): + rec_list = get_rec_list(kwd_file) + rec_sizes = {rec: get_data_size(kwd_file, rec) + for rec in rec_list} + return rec_sizes + +@h5f.h5_wrap +def get_rec_starts(kwd_file): + logger.debug('Getting rec_starts') + rec_sizes = get_rec_sizes(kwd_file) + #logger.debug('rec sizes {}'.format(rec_sizes)) + starts_vec = np.array(list(rec_sizes.values())).cumsum() + #logger.debug('starts vector {}'.format(starts_vec)) + starts_vec = np.hstack([0, starts_vec[:-1]]) + rec_starts = {rec: r_start for r_start, + rec in zip(starts_vec, rec_sizes.keys())} + return rec_starts @h5f.h5_wrap def get_corresponding_rec(kwik, stamps): diff --git a/swissknife/bci/units.py b/swissknife/bci/units.py index 1c69b7c..d5bf6da 100755 --- a/swissknife/bci/units.py +++ b/swissknife/bci/units.py @@ -51,7 +51,7 @@ def __init__(self, clu, h5_file=None): self.get_time_stamps() def get_rec_offsets(self): - self.recording_offsets = kf.rec_start_array(self.h5_file) + self.recording_offsets = kf.get_rec_starts(self.h5_file) return self.recording_offsets def get_sampling_rate(self): @@ -261,6 +261,7 @@ def __init__(self, clu, group=0, kwik_file=None, sort=0): self.avg_waveform = None self.main_chan = None self.main_wave = None + self.isi_hist_tuple = None self.waveform_pars = {} if kwik_file is not None: @@ -278,19 +279,19 @@ def get_time_stamps(self): t_path = "/channel_groups/{0:d}/spikes/time_samples".format(self.group) r_path = "/channel_groups/{0:d}/spikes/recording".format(self.group) - dtype = self.kwik_file[t_path].dtype # time samples are relative to the beginning of the corresponding rec - time_samples = np.array(self.kwik_file[t_path][self.kwik_file[clu_path][:] == self.clu], - dtype=np.dtype(dtype)) + this_clu = self.kwik_file[clu_path][:] == self.clu + all_t = self.kwik_file[t_path][:] + all_rec = self.kwik_file[r_path][:] + + dtype = self.kwik_file[t_path].dtype + self.time_samples = np.array( + all_t[this_clu], dtype=np.dtype(dtype)) dtype = self.kwik_file[r_path].dtype # recordings ids (as in the key) - recordings = np.array(self.kwik_file[r_path][self.kwik_file[clu_path][:] == self.clu], - dtype=np.dtype(dtype)) - - # patch for a random kilosort error that throws a random 0 for a time_stamp - self.time_samples = time_samples[time_samples > 0] - self.recordings = recordings[time_samples > 0] + self.recordings = np.array( + all_rec[this_clu], dtype=np.dtype(dtype)) return self.time_samples, self.recordings def get_rec_offsets(self): @@ -367,7 +368,7 @@ def get_isi(self): return all_isi_ms - def get_isi_dist(self, bin_size_ms=1, max_t=100): + def get_isi_dist(self, bin_size_ms=0.5, max_t=100, one_sided=False): if self.time_samples is None: self.get_time_stamps() all_isi_ms = np.round(np.diff(self.time_samples)/(self.sampling_rate * 0.001)) @@ -378,7 +379,24 @@ def get_isi_dist(self, bin_size_ms=1, max_t=100): bins = bins[1:] two_side_bins = np.concatenate([-bins[::-1], bins[1:]]) two_side_hist = np.concatenate([hist[::-1], hist[1:]]) - return two_side_hist, two_side_bins + + + self.isi_hist_tuple = (hist, bins) + if one_sided: + return bins, hist + else: + return two_side_hist, two_side_bins + + def get_isi_violations(self, threshold=0.05, refractory_ms=1): + if self.isi_hist_tuple is None: + self.get_isi_dist() + hist, bins = self.isi_hist_tuple + + total_spikes = np.sum(hist) + violations = np.sum(hist[bins <= refractory_ms]) + violations_ratio = violations/total_spikes + return violations_ratio, violations_ratio < threshold + def get_folder(self): return os.path.split(os.path.abspath(self.kwik_file.filename))[0] @@ -451,7 +469,7 @@ def get_principal_channels(self, projectors=4): principal_channels = pc_ind[self.clu][principal_projections] return principal_channels, principal_projections - def get_unit_spikes(self, before=20, after=20, only_principal=False): + def get_unit_spikes(self, before=20, after=20, only_principal=False, max_events=5000): s_f = self.sampling_rate valid_times = self.time_samples[self.time_samples > before] valid_recs = self.recordings[self.time_samples > before] @@ -460,21 +478,35 @@ def get_unit_spikes(self, before=20, after=20, only_principal=False): logger.warn('Some frames were out of left bounds and will be discarded') logger.warn('will collect only {0} events...'.format(valid_times.size)) - if only_principal: - chan_list = self.get_principal_channels() - else: - chan_list = self.get_unit_chans() + + chan_list = self.get_unit_chans() self.waveform_pars = {'before': before, 'after': after, 'chan_list': np.array(chan_list)} - - self.all_waveforms = collect_frames_fast(valid_times - before, - before + after, - s_f, - self.get_kwd_path(), - valid_recs, - np.array(chan_list)) + + + try: + assert valid_times.size > 1, 'no valid events' + # get a random sample of max_events elements + sample = np.random.choice(np.arange(valid_times.size), + size=min(max_events, valid_times.size), + replace=False) + self.all_waveforms = h5f.collect_frames_fast(self.get_kwd_path(), + valid_recs[sample], + valid_times[sample] - + before, + before + after, + np.array(chan_list)) + except (ValueError, AssertionError) as err: + logger.warn( + 'Could not retrieve waveforms for clu {}, error {}'.format(self.clu, + err)) + self.all_waveforms = np.zeros( + [1, before + after, np.array(chan_list).size]) + self.all_waveforms[:] = np.nan + self.save_unit_spikes() + return self.waveforms def load_all_waveforms(self): @@ -523,7 +555,7 @@ def get_unit_main_chans(self, n_chans=1): main_chan_absolute = np.array(self.waveform_pars['chan_list'])[main_chans] return main_chans.astype(np.int), main_chan_absolute - def get_unit_main_wave(self, n_chans=1): + def get_unit_main_wave(self, n_chans=4): ch = self.get_unit_main_chans(n_chans=n_chans)[0] return self.waveforms[:, :, ch] @@ -549,7 +581,7 @@ def get_all_unit_widths(self): def get_unit_widths(self): widths = self.get_all_unit_widths() return np.median(widths), np.std(widths) - + def support_vector_ms(starts, len_samples, all_units, win_size=10, s_f=30000, history_steps=1, step_size=1, @@ -986,27 +1018,49 @@ def collect_frames_array(starts, span, s_f, kwd_file, recs_list, chan_list): logger.info('Done collecting') return np.concatenate(all_frames_array, axis=0) - -def collect_frames_fast(starts, span, s_f, kwd_file, recs_list, chan_list): - recs = np.unique(recs_list) - logger.info('Collecting {} recs...'.format(recs.size)) - all_frames_list = [] - for i_rec, rec in tqdm(enumerate(recs)): - starts_from_rec = starts[recs_list == rec] - logger.info("Rec {0}, {1} events ...".format(rec, starts_from_rec.size)) - - h5_dset = h5f.get_data_set(kwd_file, rec) - n_samples = h5_dset.shape[0] - - valid_starts = starts_from_rec[(starts_from_rec > 0) - & (starts_from_rec + span < n_samples)] - if valid_starts.size < starts_from_rec.size: - logger.warn('Some frames were out of bounds and will be discarded') - logger.warn('will collect only {0} events...'.format(valid_starts.size)) - - this_rec_spikes = st.repeated_slice(h5_dset, valid_starts, span, chan_list) - all_frames_list.append(this_rec_spikes) - - logger.info('Done collecting') - return np.concatenate(all_frames_list, axis=0) - +# @h5f.h5_decorator(default_mode='r') +# def collect_frames_fast(kwd_file, recs_list, starts, span, chan_list): +# recs = np.unique(recs_list) +# all_frames_list = [] +# for i_rec, rec in tqdm(enumerate(recs), total=recs.size): +# starts_from_rec = starts[recs_list == rec] +# dset = h5f.get_data_set(kwd_file, rec) +# n_samples = dset.shape[0] +# valid_starts = starts_from_rec[(starts_from_rec > 0) +# & (starts_from_rec + span < n_samples)] +# if valid_starts.size < starts_from_rec.size: +# logger.warn('Some frames were out of bounds and will be discarded') +# logger.warn('will collect only {0} events...'.format( +# valid_starts.size)) + +# # get the dataset slices for only the channel list +# this_rec_frames = get_slice_array(dset, valid_starts, span, chan_list) +# all_frames_list.append(this_rec_frames) + +# try: +# all_frames_array = np.concatenate(all_frames_list, axis=0) +# except ValueError: +# raise +# # logger.warn('Failed to collect stream frames, return is nan array') +# # zero_dset_shape = get_data_set(kwd_file, rec).shape +# # all_frames_array = np.empty([1, *zero_dset_shape]) +# # all_frames_array[:] = np.nan +# return all_frames_array + +# def get_slice_array(dset: np.ndarray, starts: np.ndarray, span: np.int, chan_list) -> np.ndarray: +# n_slices = starts.size +# #n_chan = dset.shape[1] +# n_chan = chan_list.size +# chan_list = np.array(chan_list) +# slices_array = np.zeros([n_slices, span, n_chan]) +# #logger.info('dset {}, starts {}, span {}, chan_list {}'.format(dset.shape, starts, span, chan_list)) +# #logger.info('starts {}'.format(starts.dtype)) +# for i, start in enumerate(starts.astype(np.int64)): +# #logger.info('start {}'.format(start.shape)) +# #logger.info('chan_list {}'.format(chan_list)) +# #logger.info('i {}'.format(i)) +# #logger.info('span'.format(span)) +# #start = np.int(start) +# #aux = dset[5: 5 + span, chan_list] +# slices_array[i, :, :] = dset[start: start + span, chan_list] +# return slices_array \ No newline at end of file diff --git a/swissknife/hilevel/metrics.py b/swissknife/hilevel/metrics.py index 43e4c26..bf92de1 100755 --- a/swissknife/hilevel/metrics.py +++ b/swissknife/hilevel/metrics.py @@ -2,6 +2,7 @@ import numpy as np import matplotlib.pyplot as plt import pandas as pd +import scipy from tqdm import tqdm from swissknife.streamtools import spectral as sp @@ -70,6 +71,43 @@ def rms_slices(s_1, s_2): return rms_slice, rms_total +def equal_shaped(x_in: np.array, y_in: np.array, warp='nowarp') -> tuple: + # return the two largest possible x, y with same dimension 2 + xy = [x_in, y_in] + lengths = np.array([a.shape[-1] for a in xy]) + #len_diff = np.diff(lengths) + # if(np.abs(len_diff)>2): + # logger.warning('Spectrograms differ in {} ms'.format(len_diff)) + if warp is 'nowarp': + shorter_t = np.min(lengths) + x = x_in[:, :shorter_t] + y = y_in[:, :shorter_t] + + elif warp is 'median': + # warp to shorter time array + shorter_t = np.min(lengths) + sorted_l = np.argsort(lengths) + + x_short = xy[sorted_l[0]] + x_long = xy[sorted_l[1]] + + longer_t = x_long.shape[-1] + t_long_warped = np.arange(shorter_t)*longer_t/shorter_t + t_slice_long = t_long_warped.astype(np.int) + x = x_short + y = x_long[:, t_slice_long] + else: + raise NotImplementedError('Dont know who to warp [{}]'.format(warp)) + + return x, y + +def normalize_spec(sx): + sx -= np.amin(sx) + sx_max = np.amax(sx) + + sx /= sx_max + return sx + def compare_spectra_old(x, y, s_f=30000, n_perseg=1024, step_s=0.001, db_cut=65, f_min=300, f_max=7500, log=False, plots=None): @@ -105,27 +143,35 @@ def compare_spectra_old(x, y, s_f=30000, n_perseg=1024, step_s=0.001, db_cut=65, return rms_slices(short, long[:, :shorter_t]) -def compare_spectra(x, y, s_f=30000, n_perseg=1024, step_s=0.001, db_cut=65, f_min=0, f_max=12000, plots=False): - # make sure sizes are right - xy = [x, y] - lengths = np.array([a.shape[-1] for a in xy]) - len_diff = np.diff(lengths) - # if(np.abs(len_diff)>2): - # logger.warning('Spectrograms differ in {} ms'.format(len_diff)) - shorter_t = np.min(lengths) - sorted_l = np.argsort(lengths) +def compare_spectra(x, y, s_f=30000, n_perseg=1024, step_s=0.001, db_cut=75, f_min=0, f_max=12000, +plots=False, warp='nowarp'): + + if warp=='nowarp': + # make sure sizes are right + xy = [x, y] + lengths = np.array([a.shape[-1] for a in xy]) + len_diff = np.diff(lengths) + # if(np.abs(len_diff)>2): + # logger.warning('Spectrograms differ in {} ms'.format(len_diff)) + shorter_t = np.min(lengths) + sorted_l = np.argsort(lengths) - short = xy[sorted_l[0]] - long = xy[sorted_l[1]] + short = xy[sorted_l[0]] + long = xy[sorted_l[1]] + x = x[: shorter_t] + y = y[: shorter_t] - fx, tx, sx = sp.pretty_spectrogram(normalize(x[:shorter_t]), s_f, fft_size=n_perseg, log=True, + fx, tx, sx = sp.pretty_spectrogram(normalize(x), s_f, fft_size=n_perseg, log=True, step_size=int(s_f * step_s), db_cut=db_cut, f_min=f_min, f_max=f_max, window=('gaussian', 120)) - fy, ty, sy = sp.pretty_spectrogram(normalize(y[:shorter_t]), s_f, fft_size=n_perseg, log=True, + fy, ty, sy = sp.pretty_spectrogram(normalize(y), s_f, fft_size=n_perseg, log=True, step_size=int(s_f * step_s), db_cut=db_cut, f_min=f_min, f_max=f_max, window=('gaussian', 120)) + if warp == 'median': + sx, sy = equal_shaped(sx, sy, warp='median') + if plots: plt.imshow(((sx))[::-1], aspect='auto', cmap='inferno') plt.grid(False) @@ -146,10 +192,23 @@ def compare_spectra(x, y, s_f=30000, n_perseg=1024, step_s=0.001, db_cut=65, f_m sy -= np.amin(sy) sy /= np.amax(sx_max) - rms = np.linalg.norm(sx - sy) / np.sqrt(sx.size) + # deal with zeros to compute the spectrogram correlations + f_bin, t_bin = sx.shape + + #zero_x = np.where((sx.sum(axis=0)<10) & (sy.sum(axis=0) < 10) )[0] + zero_x = np.where(sx.sum(axis=0)<1)[0] + + mu = 0.01 + epsilon = mu*8e-17 + + x_jitter = np.random.normal(mu, epsilon, (f_bin, zero_x.size)) + y_jitter = np.random.normal(mu, epsilon, (f_bin, zero_x.size)) + + sx[:, zero_x] = x_jitter + sy[:, zero_x] = y_jitter - rho = np.corrcoef(sx, sy) - return rms, sx, sy + rxy = np.array([scipy.stats.pearsonr(i, j)[0] for i,j in zip(sx.T, sy.T)]) + return rxy, sx, sy def mot_scores(mot_id, Y, Z, mod_pred, sess_data, win_samples=64, other_pd=None): @@ -281,7 +340,7 @@ def all_self_scores(one_pd, other_pd, pass_thru=[]): logger = logging.getLogger() all_mots = one_pd['m_id'].tolist() - logger.info('Found {} mots'.format(len(all_mots))) + logger.info('Found {} mots your mom'.format(len(all_mots))) all_scores = [] logger.disabled = True @@ -291,8 +350,8 @@ def all_self_scores(one_pd, other_pd, pass_thru=[]): one_pd['syn_song'].tolist(), one_pd['raw_song'].tolist())), total=len(one_pd['m_id'].tolist())): - rms_raw = compare_spectra(x, x_raw, n_perseg=64, db_cut=65)[0] - rms_syn = compare_spectra(x, x_syn, n_perseg=64, db_cut=55)[0] + rms_raw, sxx_neu, sxx_raw = compare_spectra(x, x_raw, n_perseg=64, db_cut=55) + rms_syn, _, sxx_syn = compare_spectra(x, x_syn, n_perseg=64, db_cut=55) rms_syn_raw = compare_spectra(x_syn, x_raw, n_perseg=64, db_cut=55)[0] rms_con = np.array( list(map(lambda z: compare_spectra(x, z, n_perseg=128, db_cut=80)[0], other_pd['x'].tolist()))) @@ -306,11 +365,14 @@ def all_self_scores(one_pd, other_pd, pass_thru=[]): rms_bos_bos = np.array(bos_bos) cross_mot_id = other_pd['m_id'].tolist() - all_scores.append([m_id, rms_raw, rms_syn, rms_syn_raw, rms_con, rms_syn_con, rms_bos_con, rms_bos_bos, cross_mot_id]) + all_scores.append([m_id, rms_raw, rms_syn, rms_syn_raw, rms_con, + rms_syn_con, rms_bos_con, rms_bos_bos, cross_mot_id, + sxx_raw, sxx_neu, sxx_syn]) logger.disabled = False - headers = ['m_id', 'rms_raw', 'rms_syn', 'rms_syn_raw', 'rms_con', 'rms_syn_con', 'rms_bos_con', 'rms_bos_bos', 'vs_id'] - + headers = ['m_id', 'rms_raw', 'rms_syn', 'rms_syn_raw', + 'rms_con', 'rms_syn_con', 'rms_bos_con', 'rms_bos_bos', 'vs_id', + 'sxx_raw', 'sxx_neu', 'sxx_syn'] pd_all_scores = pd.DataFrame(all_scores, columns=headers) # append passtrhu fields diff --git a/swissknife/hilevel/plotin.py b/swissknife/hilevel/plotin.py new file mode 100644 index 0000000..bb5afd2 --- /dev/null +++ b/swissknife/hilevel/plotin.py @@ -0,0 +1,44 @@ +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +def plot_line_ci(x, x_var, color='r', label='arr', ls='-', lw=1, ax=None, + alpha=0.04, t=np.empty(0)): + + if ax is None: + f, ax = plt.subplots() + + if t.size==0: + t = np.arange(x.size) + ax.plot(t, x, lw=lw, ls=ls, color=color, alpha=1, label=label) + ax.fill_between(t, x + x_var, x - x_var, color=color, + alpha=0.07) + return ax + + +# reads kai +# /mnt/cube/kai/results/spectrogram prediction model/mel/Resub_2018/lstm/ffnn_lstm.p' +# type of file +# adds two indices for easily grouping by trial and averagin by time +def load_kai_pd(pickle_path): + kai_pd = pd.read_pickle(pickle_path) + kai_pd['t'] = kai_pd['time']*30000 + kai_pd['t'] = kai_pd['t'].apply(np.int) + kai_pd['idx'] = kai_pd.index + kai_pd.set_index(['t', 'idx'],inplace=True) + kai_pd.sort_index(inplace=True) + kai_pd.sort_values(['t', 'idx']) + return kai_pd + +def pd_to_arrays(k_pd, measure='correlation', model='LSTM'): + # return an array with mean, std values + # t in ms + sel_filter = k_pd['model']==model + + pd_grouped = k_pd.loc[sel_filter, :].groupby('t') + + t = pd_grouped['time'].mean().values*1000 + avg = pd_grouped[measure].mean() + err = pd_grouped[measure].std() + + return t.astype(np.int), avg, err \ No newline at end of file