From 85494cbff37f8dd1e0b71475f93e8ebe68dea7f2 Mon Sep 17 00:00:00 2001
From: Zeke Arneodo
Date: Wed, 7 Nov 2018 09:48:11 -0800
Subject: [PATCH] before sending manuscript to Tim

---
 swissknife/bci/core/expstruct.py         | 18 ++++++++
 swissknife/bci/core/file/h5_functions.py |  2 +
 swissknife/bci/core/kwik_functions.py    | 11 ++---
 swissknife/bci/units.py                  | 57 ++++++++++++++++++++----
 swissknife/hilevel/ffnn.py               |  2 +-
 swissknife/hilevel/metrics.py            | 10 ++---
 swissknife/streamtools/streams.py        | 10 +++++
 7 files changed, 90 insertions(+), 20 deletions(-)

diff --git a/swissknife/bci/core/expstruct.py b/swissknife/bci/core/expstruct.py
index b8e5d33..d4de890 100755
--- a/swissknife/bci/core/expstruct.py
+++ b/swissknife/bci/core/expstruct.py
@@ -7,6 +7,7 @@
 import h5py  # for more than structure
 import numpy as np
+import pandas as pd
 import yaml
 from numpy.lib import recfunctions as rf
@@ -175,6 +176,15 @@ def mkdir_p(path):
         else:
             raise
 
+def sub_dirs(path):
+    return [d for d in glob.glob(os.path.join(path, '*')) if os.path.isdir(d)]
+
+
+def list_birds(folder, breed='z'):
+    all_dirs = [os.path.split(d)[-1] for d in sub_dirs(folder)]
+    all_birds = [b for b in all_dirs if b.startswith(breed)]
+    all_birds.sort()
+    return all_birds
 
 def list_sessions(bird, experiment_folder=None, location='ss'):
     fn = file_names(bird, experiment_folder=experiment_folder)
@@ -183,6 +193,7 @@
     return sessions_bird
 
+
 def list_raw_sessions(bird, sess_day=None, depth='', experiment_folder=None, location='raw'):
     all_sessions = list_sessions(bird, experiment_folder=experiment_folder, location=location)
     if sess_day is not None:
@@ -193,6 +204,13 @@ def list_raw_sessions(bird, sess_day=None, depth='', experiment_folder=None, loc
     return all_sessions, all_depths
 
 
+def get_sessions_info_pd(breed='z', location='rw'):
+    folder = file_names('')['folders'][location]
+    info_pd = pd.DataFrame(list_birds(folder, breed=breed), columns=['bird'])
+    info_pd['sessions'] = info_pd['bird'].apply(lambda x: list_sessions(x, location=location))
+    return info_pd
+
+
 # Experiment structure
 def get_parameters(bird, sess, rec=0, experiment_folder=None, location='ss'):
     fn = file_names(bird, sess, rec, experiment_folder=experiment_folder)
diff --git a/swissknife/bci/core/file/h5_functions.py b/swissknife/bci/core/file/h5_functions.py
index 03ead41..9b10469 100755
--- a/swissknife/bci/core/file/h5_functions.py
+++ b/swissknife/bci/core/file/h5_functions.py
@@ -119,8 +119,10 @@ def get_rec_sizes(kwd_file):
 
 @h5_wrap
 def get_corresponding_rec(kwd_file, sample):
+    # returns the ordinal position of the rec this sample falls in (not the rec_id)
     rec_starts = get_rec_starts(kwd_file)
     rec_index = np.min(np.where(sample > rec_starts))
+    return rec_index
 
 
 @h5_wrap
diff --git a/swissknife/bci/core/kwik_functions.py b/swissknife/bci/core/kwik_functions.py
index 3fd5132..ef6fbd6 100755
--- a/swissknife/bci/core/kwik_functions.py
+++ b/swissknife/bci/core/kwik_functions.py
@@ -162,11 +162,12 @@ def get_rec_list(k_file):
 @h5f.h5_wrap
 def rec_start_array(kwik):
     rec_list = list(map(int, get_rec_list(kwik)))
-    rec_array = np.arange(max(rec_list) + 1)
-    start_array = np.zeros_like(rec_array)
-    for i_rec in rec_list:
-        start_array[i_rec] = get_rec_start_sample(kwik, i_rec)
-    return start_array
+    # rec_array = np.arange(max(rec_list) + 1)
+    # start_array = np.zeros_like(rec_array)
+    start_array = [get_rec_start_sample(kwik, rec_id) for rec_id in rec_list]
+    # for i_rec, rec_id in enumerate(rec_list):
+    #     start_array[i_rec] = get_rec_start_sample(kwik, rec_id)
+    return np.array(start_array)
 
 
 @h5f.h5_wrap
diff --git a/swissknife/bci/units.py b/swissknife/bci/units.py
index f4f26a7..1c69b7c 100755
--- a/swissknife/bci/units.py
+++ b/swissknife/bci/units.py
@@ -255,6 +255,8 @@ def __init__(self, clu, group=0, kwik_file=None, sort=0):
         self.sampling_rate = None
         self.kwd_file = None
 
+        self.all_waveforms = None  # all of the waveforms
+        self.n_waveforms = 1000  # size of the waveform sample to show/compute
         self.waveforms = None
         self.avg_waveform = None
         self.main_chan = None
@@ -277,10 +279,12 @@ def get_time_stamps(self):
         r_path = "/channel_groups/{0:d}/spikes/recording".format(self.group)
 
         dtype = self.kwik_file[t_path].dtype
+        # time samples are relative to the beginning of the corresponding rec
         time_samples = np.array(self.kwik_file[t_path][self.kwik_file[clu_path][:] == self.clu],
                                 dtype=np.dtype(dtype))
 
         dtype = self.kwik_file[r_path].dtype
+        # recording ids (as in the kwik file keys)
         recordings = np.array(self.kwik_file[r_path][self.kwik_file[clu_path][:] == self.clu],
                               dtype=np.dtype(dtype))
@@ -376,7 +380,6 @@ def get_isi_dist(self, bin_size_ms=1, max_t=100):
         two_side_hist = np.concatenate([hist[::-1], hist[1:]])
         return two_side_hist, two_side_bins
 
-
     def get_folder(self):
         return os.path.split(os.path.abspath(self.kwik_file.filename))[0]
@@ -407,22 +410,23 @@ def save_unit_spikes(self):
                                  'unit_{}_{:03d}.npy'.format(self.group, self.clu))
 
         logger.info('Saving unit {0} in file {1}'.format(self.clu, file_path))
-        np.save(file_path, self.waveforms)
+        np.save(file_path, self.all_waveforms)
 
         par_path = os.path.join(file_folder,
                                 'unit_{}_{:03d}.par.pickle'.format(self.group, self.clu))
         pickle.dump(self.waveform_pars, open(par_path, 'wb'))
 
     def load_unit_spikes(self):
-        logger.info('will try to load previous unit files')
+        logger.debug('will try to load previous unit files')
         # careful, loads the last saved
         folder = self.get_folder()
         f_name = 'unit_{}_{:03d}.npy'.format(self.group, self.clu)
         p_name = 'unit_{}_{:03d}.par.pickle'.format(self.group, self.clu)
 
         self.waveform_pars = pickle.load(open(os.path.join(folder, 'unit_waveforms', p_name), 'rb'))
-        self.waveforms = np.load(os.path.join(folder, 'unit_waveforms', f_name))
-        return self.waveforms
+        self.all_waveforms = np.load(os.path.join(folder, 'unit_waveforms', f_name), mmap_mode='r')
+
+        return self.all_waveforms
 
     def get_principal_channels(self, projectors=4):
         kilo_path = self.get_kilo_folder()
@@ -463,22 +467,26 @@ def get_unit_spikes(self, before=20, after=20, only_principal=False):
 
         self.waveform_pars = {'before': before,
                               'after': after,
-                              'chan_list': chan_list}
+                              'chan_list': np.array(chan_list)}
 
-        self.waveforms = collect_frames_array(valid_times - before,
+        self.all_waveforms = collect_frames_fast(valid_times - before,
                                               before + after,
                                               s_f,
                                               self.get_kwd_path(),
                                               valid_recs,
-                                              chan_list)
+                                              np.array(chan_list))
         return self.waveforms
 
     def load_all_waveforms(self):
         folder = self.get_folder()
         f_name = 'unit_{:03d}.npy'.format(self.clu)
-        return np.load(os.path.join(folder, 'unit_waveforms', f_name))
+        return np.load(os.path.join(folder, 'unit_waveforms', f_name), mmap_mode='r')
+
+    def set_n_waveforms(self, n_waveforms):
+        self.n_waveforms = n_waveforms
 
     def get_waveforms(self, before=20, after=20, only_principal=False, force=False):
+
         try:
             logger.info('Trying to load waveforms file')
             assert force is False
@@ -489,6 +497,12 @@ def get_waveforms(self, before=20, after=20, only_principal=False, force=False):
                                  only_principal=only_principal)
             logger.info('will save the spikes for the nest time around')
             self.save_unit_spikes()
+        # all waveforms were loaded into self.all_waveforms.
+        # now take a random sample of them in self.waveforms, to show and to compute metrics
+        self.n_waveforms = min(self.n_waveforms, self.all_waveforms.shape[0])
+        waveform_samples = np.random.choice(self.all_waveforms.shape[0], self.n_waveforms,
+                                            replace=False)
+        self.waveforms = self.all_waveforms[waveform_samples, :, :]
         return self.waveforms
 
     def get_avg_wave(self):
@@ -971,3 +985,28 @@ def collect_frames_array(starts, span, s_f, kwd_file, recs_list, chan_list):
         all_frames_array.append(rec_frames)
     logger.info('Done collecting')
     return np.concatenate(all_frames_array, axis=0)
+
+
+def collect_frames_fast(starts, span, s_f, kwd_file, recs_list, chan_list):
+    recs = np.unique(recs_list)
+    logger.info('Collecting {} recs...'.format(recs.size))
+    all_frames_list = []
+    for i_rec, rec in tqdm(enumerate(recs)):
+        starts_from_rec = starts[recs_list == rec]
+        logger.info("Rec {0}, {1} events ...".format(rec, starts_from_rec.size))
+
+        h5_dset = h5f.get_data_set(kwd_file, rec)
+        n_samples = h5_dset.shape[0]
+
+        valid_starts = starts_from_rec[(starts_from_rec > 0)
+                                       & (starts_from_rec + span < n_samples)]
+        if valid_starts.size < starts_from_rec.size:
+            logger.warn('Some frames were out of bounds and will be discarded')
+            logger.warn('will collect only {0} events...'.format(valid_starts.size))
+
+        this_rec_spikes = st.repeated_slice(h5_dset, valid_starts, span, chan_list)
+        all_frames_list.append(this_rec_spikes)
+
+    logger.info('Done collecting')
+    return np.concatenate(all_frames_list, axis=0)
+
diff --git a/swissknife/hilevel/ffnn.py b/swissknife/hilevel/ffnn.py
index f4771d2..77798f7 100755
--- a/swissknife/hilevel/ffnn.py
+++ b/swissknife/hilevel/ffnn.py
@@ -179,7 +179,7 @@ def _build_graph(self):
 
         return x, y, y_, r_loss, c_solver, summaries, global_step
 
-    def train(self, X, Y, max_iter=np.inf, max_epochs=np.inf, cross_validate=True, verbose=True):
+    def train(self, X, Y, max_iter=np.inf, max_epochs=np.inf, cross_validate=True, verbose=False):
         # set aside train/validation split
         x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size=self.validation_split)
         r_losses = []
diff --git a/swissknife/hilevel/metrics.py b/swissknife/hilevel/metrics.py
index b3f060b..43e4c26 100755
--- a/swissknife/hilevel/metrics.py
+++ b/swissknife/hilevel/metrics.py
@@ -249,7 +249,7 @@ def all_mot_decoded_pcwise(y, z, y_p, sess_data):
         fit_target = 'dyn'
 
     s_f = int(sess_data.s_f)
-    logger.info('Getting streams and reconstructions of {} motifs, target fit is {}'.format(mot_ids.size, fit_target))
+    logger.info('Getting streams and reconstructions of {} motifs, target fit is {}, dim {}'.format(mot_ids.size, fit_target, y_p.shape[-1]))
     all_decoded = []
     for i, m_id in tqdm(enumerate(mot_ids), total=mot_ids.size):
@@ -291,15 +291,15 @@ def all_self_scores(one_pd, other_pd, pass_thru=[]):
                                            one_pd['syn_song'].tolist(),
                                            one_pd['raw_song'].tolist())),
                                   total=len(one_pd['m_id'].tolist())):
-        rms_raw = compare_spectra(x, x_raw, n_perseg=64, db_cut=55)[0]
+        rms_raw = compare_spectra(x, x_raw, n_perseg=64, db_cut=65)[0]
         rms_syn = compare_spectra(x, x_syn, n_perseg=64, db_cut=55)[0]
         rms_syn_raw = compare_spectra(x_syn, x_raw, n_perseg=64, db_cut=55)[0]
 
         rms_con = np.array(
-            list(map(lambda z: compare_spectra(x, z, n_perseg=128, db_cut=90)[0], other_pd['x'].tolist())))
+            list(map(lambda z: compare_spectra(x, z, n_perseg=128, db_cut=80)[0], other_pd['x'].tolist())))
         rms_syn_con = np.array(
-            list(map(lambda z: compare_spectra(x_syn, z, n_perseg=128, db_cut=80)[0], other_pd['x'].tolist())))
+            list(map(lambda z: compare_spectra(x_syn, z, n_perseg=128, db_cut=70)[0], other_pd['x'].tolist())))
         rms_bos_con = np.array(
-            list(map(lambda z: compare_spectra(x_raw, z, n_perseg=128, db_cut=90)[0], other_pd['x'].tolist())))
+            list(map(lambda z: compare_spectra(x_raw, z, n_perseg=128, db_cut=80)[0], other_pd['x'].tolist())))
 
         bos_bos = [compare_spectra(x_raw, y, n_perseg=128, db_cut=80)[0] for j,y in enumerate(all_raw)
                    if not j==i]
diff --git a/swissknife/streamtools/streams.py b/swissknife/streamtools/streams.py
index 292791c..d67408e 100755
--- a/swissknife/streamtools/streams.py
+++ b/swissknife/streamtools/streams.py
@@ -10,6 +10,7 @@
 import numpy as np
 import scipy.signal as sg
 from matplotlib import pyplot as plt
+from numba import jit
 
 from swissknife.h5tools import tables as h5t
 
@@ -104,6 +105,15 @@ def sum_frames(frames_list):
     all_avg = all_frames_array.mean(axis=0)
     return all_avg
 
+@jit
+def repeated_slice(big_array: np.ndarray, starts: np.ndarray, span: int, chan_list: np.ndarray) -> np.ndarray:
+    n_starts = starts.size
+    n_cols = chan_list.size
+
+    slice_stack = np.empty([n_starts, span, n_cols])
+    for i, start in enumerate(starts):  # iterate over the start samples themselves, not their indices
+        slice_stack[i, :, :] = big_array[start: start + span, chan_list]
+    return slice_stack
 
 class WavData2:  # same as wavdata, but streams are read in columns into an N_samp X N_ch array (one channel = one column)
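
Usage sketch (not part of the patch): the snippet below shows how the windowing helper
added to swissknife/streamtools/streams.py is meant to be driven by collect_frames_fast.
It is a minimal sketch, assuming a plain NumPy array in place of the kwd h5py dataset;
the numba @jit decorator is omitted so it runs with NumPy alone, and the recording shape,
spike times and channel subset are invented for illustration.

    import numpy as np

    def repeated_slice_sketch(big_array, starts, span, chan_list):
        # same logic as the patched streams.repeated_slice:
        # stack one (span x n_channels) window per start sample
        slice_stack = np.empty([starts.size, span, chan_list.size])
        for i, start in enumerate(starts):
            slice_stack[i, :, :] = big_array[start: start + span, chan_list]
        return slice_stack

    rec = np.random.randn(10000, 32)                  # fake continuous recording (samples x channels)
    spike_samples = np.array([120, 530, 4021, 9900])  # invented spike times, in samples
    before, after = 20, 20
    chan_list = np.array([3, 4, 5])                   # invented channel subset

    # keep only windows that fit inside the recording, as collect_frames_fast does
    valid = spike_samples[(spike_samples - before > 0)
                          & (spike_samples + after < rec.shape[0])]
    waveforms = repeated_slice_sketch(rec, valid - before, before + after, chan_list)
    print(waveforms.shape)                            # -> (4, 40, 3)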