Skip to content

Commit

Permalink
before sending manuscript to Tim
Browse files Browse the repository at this point in the history
  • Loading branch information
zekearneodo committed Nov 7, 2018
1 parent 1280c1f commit 85494cb
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 20 deletions.
18 changes: 18 additions & 0 deletions swissknife/bci/core/expstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
# for more than structure
import numpy as np
import pandas as pd
import yaml
from numpy.lib import recfunctions as rf

Expand Down Expand Up @@ -175,6 +176,15 @@ def mkdir_p(path):
else:
raise

def sub_dirs(path):
return [d for d in glob.glob(os.path.join(path, '*')) if os.path.isdir(d)]


def list_birds(folder, breed='z'):
all_dirs = [os.path.split(d)[-1] for d in sub_dirs(folder)]
all_birds = [b for b in all_dirs if b.startswith(breed)]
all_birds.sort()
return all_birds

def list_sessions(bird, experiment_folder=None, location='ss'):
fn = file_names(bird, experiment_folder=experiment_folder)
Expand All @@ -183,6 +193,7 @@ def list_sessions(bird, experiment_folder=None, location='ss'):
return sessions_bird



def list_raw_sessions(bird, sess_day=None, depth='', experiment_folder=None, location='raw'):
all_sessions = list_sessions(bird, experiment_folder=experiment_folder, location=location)
if sess_day is not None:
Expand All @@ -193,6 +204,13 @@ def list_raw_sessions(bird, sess_day=None, depth='', experiment_folder=None, loc
return all_sessions, all_depths


def get_sessions_info_pd(breed='z', location='rw'):
folder = file_names('')['folders'][location]
info_pd = pd.DataFrame(list_birds(folder, breed=breed), columns=['bird'])
info_pd['sessions'] = info_pd['bird'].apply(lambda x: list_sessions(x, location=location))
return info_pd


# Experiment structure
def get_parameters(bird, sess, rec=0, experiment_folder=None, location='ss'):
fn = file_names(bird, sess, rec, experiment_folder=experiment_folder)
Expand Down
2 changes: 2 additions & 0 deletions swissknife/bci/core/file/h5_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,10 @@ def get_rec_sizes(kwd_file):

@h5_wrap
def get_corresponding_rec(kwd_file, sample):
# returns the order of rec in which this position belongs (not the rec_id)
rec_starts = get_rec_starts(kwd_file)
rec_index = np.min(np.where(sample > rec_starts))
return rec_index


@h5_wrap
Expand Down
11 changes: 6 additions & 5 deletions swissknife/bci/core/kwik_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,12 @@ def get_rec_list(k_file):
@h5f.h5_wrap
def rec_start_array(kwik):
rec_list = list(map(int, get_rec_list(kwik)))
rec_array = np.arange(max(rec_list) + 1)
start_array = np.zeros_like(rec_array)
for i_rec in rec_list:
start_array[i_rec] = get_rec_start_sample(kwik, i_rec)
return start_array
# rec_array = np.arange(max(rec_list) + 1)
# start_array = np.zeros_like(rec_array)
start_array = [get_rec_start_sample(kwik, rec_id) for rec_id in rec_list]
# for i_rec, rec_id in enumerate(rec_list):
# start_array[i_rec] = get_rec_start_sample(kwik, rec_id)
return np.array(start_array)


@h5f.h5_wrap
Expand Down
57 changes: 48 additions & 9 deletions swissknife/bci/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,8 @@ def __init__(self, clu, group=0, kwik_file=None, sort=0):
self.sampling_rate = None

self.kwd_file = None
self.all_waveforms = None # all of the waveforms
self.n_waveforms = 1000 #sample of waveforms to show/compute
self.waveforms = None
self.avg_waveform = None
self.main_chan = None
Expand All @@ -277,10 +279,12 @@ def get_time_stamps(self):
r_path = "/channel_groups/{0:d}/spikes/recording".format(self.group)

dtype = self.kwik_file[t_path].dtype
# time samples are relative to the beginning of the corresponding rec
time_samples = np.array(self.kwik_file[t_path][self.kwik_file[clu_path][:] == self.clu],
dtype=np.dtype(dtype))

dtype = self.kwik_file[r_path].dtype
# recordings ids (as in the key)
recordings = np.array(self.kwik_file[r_path][self.kwik_file[clu_path][:] == self.clu],
dtype=np.dtype(dtype))

Expand Down Expand Up @@ -376,7 +380,6 @@ def get_isi_dist(self, bin_size_ms=1, max_t=100):
two_side_hist = np.concatenate([hist[::-1], hist[1:]])
return two_side_hist, two_side_bins


def get_folder(self):
return os.path.split(os.path.abspath(self.kwik_file.filename))[0]

Expand Down Expand Up @@ -407,22 +410,23 @@ def save_unit_spikes(self):
'unit_{}_{:03d}.npy'.format(self.group,
self.clu))
logger.info('Saving unit {0} in file {1}'.format(self.clu, file_path))
np.save(file_path, self.waveforms)
np.save(file_path, self.all_waveforms)
par_path = os.path.join(file_folder,
'unit_{}_{:03d}.par.pickle'.format(self.group,
self.clu))
pickle.dump(self.waveform_pars, open(par_path, 'wb'))

def load_unit_spikes(self):
logger.info('will try to load previous unit files')
logger.debug('will try to load previous unit files')
# careful, loads the last saved
folder = self.get_folder()
f_name = 'unit_{}_{:03d}.npy'.format(self.group, self.clu)
p_name = 'unit_{}_{:03d}.par.pickle'.format(self.group, self.clu)
self.waveform_pars = pickle.load(open(os.path.join(folder, 'unit_waveforms', p_name),
'rb'))
self.waveforms = np.load(os.path.join(folder, 'unit_waveforms', f_name))
return self.waveforms
self.all_waveforms = np.load(os.path.join(folder, 'unit_waveforms', f_name), mmap_mode='r')

return self.all_waveforms

def get_principal_channels(self, projectors=4):
kilo_path = self.get_kilo_folder()
Expand Down Expand Up @@ -463,22 +467,26 @@ def get_unit_spikes(self, before=20, after=20, only_principal=False):

self.waveform_pars = {'before': before,
'after': after,
'chan_list': chan_list}
'chan_list': np.array(chan_list)}

self.waveforms = collect_frames_array(valid_times - before,
self.all_waveforms = collect_frames_fast(valid_times - before,
before + after,
s_f,
self.get_kwd_path(),
valid_recs,
chan_list)
np.array(chan_list))
return self.waveforms

def load_all_waveforms(self):
folder = self.get_folder()
f_name = 'unit_{:03d}.npy'.format(self.clu)
return np.load(os.path.join(folder, 'unit_waveforms', f_name))
return np.load(os.path.join(folder, 'unit_waveforms', f_name), mmap_mode='r')

def set_n_waveforms(self, n_waveforms):
self.n_waveforms = n_waveforms

def get_waveforms(self, before=20, after=20, only_principal=False, force=False):

try:
logger.info('Trying to load waveforms file')
assert force is False
Expand All @@ -489,6 +497,12 @@ def get_waveforms(self, before=20, after=20, only_principal=False, force=False):
only_principal=only_principal)
logger.info('will save the spikes for the nest time around')
self.save_unit_spikes()
# all waveforms were loaded into self.all_waveforms.
# now we want to make a sample fo them in self.waveforms, to show and compute metrics
self.n_waveforms = min(self.n_waveforms, self.all_waveforms.shape[0])
waveform_samples = np.random.choice(self.all_waveforms.shape[0], self.n_waveforms,
replace=False)
self.waveforms = self.all_waveforms[waveform_samples, :, :]
return self.waveforms

def get_avg_wave(self):
Expand Down Expand Up @@ -971,3 +985,28 @@ def collect_frames_array(starts, span, s_f, kwd_file, recs_list, chan_list):
all_frames_array.append(rec_frames)
logger.info('Done collecting')
return np.concatenate(all_frames_array, axis=0)


def collect_frames_fast(starts, span, s_f, kwd_file, recs_list, chan_list):
recs = np.unique(recs_list)
logger.info('Collecting {} recs...'.format(recs.size))
all_frames_list = []
for i_rec, rec in tqdm(enumerate(recs)):
starts_from_rec = starts[recs_list == rec]
logger.info("Rec {0}, {1} events ...".format(rec, starts_from_rec.size))

h5_dset = h5f.get_data_set(kwd_file, rec)
n_samples = h5_dset.shape[0]

valid_starts = starts_from_rec[(starts_from_rec > 0)
& (starts_from_rec + span < n_samples)]
if valid_starts.size < starts_from_rec.size:
logger.warn('Some frames were out of bounds and will be discarded')
logger.warn('will collect only {0} events...'.format(valid_starts.size))

this_rec_spikes = st.repeated_slice(h5_dset, valid_starts, span, chan_list)
all_frames_list.append(this_rec_spikes)

logger.info('Done collecting')
return np.concatenate(all_frames_list, axis=0)

2 changes: 1 addition & 1 deletion swissknife/hilevel/ffnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def _build_graph(self):

return x, y, y_, r_loss, c_solver, summaries, global_step

def train(self, X, Y, max_iter=np.inf, max_epochs=np.inf, cross_validate=True, verbose=True):
def train(self, X, Y, max_iter=np.inf, max_epochs=np.inf, cross_validate=True, verbose=False):
# set aside train/validation split
x_train, x_val, y_train, y_val = train_test_split(X, Y, test_size=self.validation_split)
r_losses = []
Expand Down
10 changes: 5 additions & 5 deletions swissknife/hilevel/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def all_mot_decoded_pcwise(y, z, y_p, sess_data):
fit_target = 'dyn'

s_f = int(sess_data.s_f)
logger.info('Getting streams and reconstructions of {} motifs, target fit is {}'.format(mot_ids.size, fit_target))
logger.info('Getting streams and reconstructions of {} motifs, target fit is {}, dim {}'.format(mot_ids.size, fit_target, y_p.shape[-1]))

all_decoded = []
for i, m_id in tqdm(enumerate(mot_ids), total=mot_ids.size):
Expand Down Expand Up @@ -291,15 +291,15 @@ def all_self_scores(one_pd, other_pd, pass_thru=[]):
one_pd['syn_song'].tolist(),
one_pd['raw_song'].tolist())),
total=len(one_pd['m_id'].tolist())):
rms_raw = compare_spectra(x, x_raw, n_perseg=64, db_cut=55)[0]
rms_raw = compare_spectra(x, x_raw, n_perseg=64, db_cut=65)[0]
rms_syn = compare_spectra(x, x_syn, n_perseg=64, db_cut=55)[0]
rms_syn_raw = compare_spectra(x_syn, x_raw, n_perseg=64, db_cut=55)[0]
rms_con = np.array(
list(map(lambda z: compare_spectra(x, z, n_perseg=128, db_cut=90)[0], other_pd['x'].tolist())))
list(map(lambda z: compare_spectra(x, z, n_perseg=128, db_cut=80)[0], other_pd['x'].tolist())))
rms_syn_con = np.array(
list(map(lambda z: compare_spectra(x_syn, z, n_perseg=128, db_cut=80)[0], other_pd['x'].tolist())))
list(map(lambda z: compare_spectra(x_syn, z, n_perseg=128, db_cut=70)[0], other_pd['x'].tolist())))
rms_bos_con = np.array(
list(map(lambda z: compare_spectra(x_raw, z, n_perseg=128, db_cut=90)[0], other_pd['x'].tolist())))
list(map(lambda z: compare_spectra(x_raw, z, n_perseg=128, db_cut=80)[0], other_pd['x'].tolist())))

bos_bos = [compare_spectra(x_raw, y, n_perseg=128,
db_cut=80)[0] for j,y in enumerate(all_raw) if not j==i]
Expand Down
10 changes: 10 additions & 0 deletions swissknife/streamtools/streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
import scipy.signal as sg
from matplotlib import pyplot as plt
from numba import jit

from swissknife.h5tools import tables as h5t

Expand Down Expand Up @@ -104,6 +105,15 @@ def sum_frames(frames_list):
all_avg = all_frames_array.mean(axis=0)
return all_avg

@jit
def repeated_slice(big_array: np.array, starts: np.array, span: int, chan_list: np.array) -> np.array:
n_starts = starts.size
n_cols = chan_list.size

slice_stack = np.empty([n_starts, span, n_cols])
for i, start in enumerate(range(n_starts)):
slice_stack[i, :, :] = big_array[start: start + span, chan_list]
return slice_stack

class WavData2:
# same as wavdata, but streams are read in columns into an N_samp X N_ch array (one channel = one column)
Expand Down

0 comments on commit 85494cb

Please sign in to comment.