Skip to content

Commit

Permalink
Merge pull request #167 from DradeAW/master
Browse files Browse the repository at this point in the history
Load templates as memmap rather than in memory
  • Loading branch information
alejoe91 authored Apr 6, 2024
2 parents 4c1324b + f4c7dd8 commit 57483d7
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 18 deletions.
6 changes: 3 additions & 3 deletions docs/generate_templates.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,9 @@ Drifting parameters summary
max_drift: 100 # max distance from the initial and final cell position
min_drift: 30 # min distance from the initial and final cell position
drift_steps: 50 # number of drift steps
drift_x_lim: [-10, 10] # drift limits in the x-direction
drift_y_lim: [-10, 10] # drift limits in the y-direction
drift_z_lim: [20, 80] # drift limits in the z-direction
drift_xlim: [-10, 10] # drift limits in the x-direction
drift_ylim: [-10, 10] # drift limits in the y-direction
drift_zlim: [20, 80] # drift limits in the z-direction
Running template generation using Python
Expand Down
6 changes: 3 additions & 3 deletions src/MEArec/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,17 @@ def generate_drift_dict_from_params(
# triangle / sine frequency depends on the velocity
freq = 1.0 / (2 * half_period)

times = np.arange(end_drift_index - start_drift_index) / drift_fs
drift_times = np.arange(end_drift_index - start_drift_index) / drift_fs

if slow_drift_waveform == "triangluar":
triangle = np.abs(scipy.signal.sawtooth(2 * np.pi * freq * times + np.pi / 2))
triangle = np.abs(scipy.signal.sawtooth(2 * np.pi * freq * drift_times + np.pi / 2))
triangle *= slow_drift_amplitude
triangle -= slow_drift_amplitude / 2.0

drift_vector_um[start_drift_index:end_drift_index] = triangle
drift_vector_um[end_drift_index:] = triangle[-1]
elif slow_drift_waveform == "sine":
sine = np.cos(2 * np.pi * freq * times + np.pi / 2)
sine = np.cos(2 * np.pi * freq * drift_times + np.pi / 2)
sine *= slow_drift_amplitude / 2.0

drift_vector_um[start_drift_index:end_drift_index] = sine
Expand Down
6 changes: 4 additions & 2 deletions src/MEArec/generators/recordinggenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,9 @@ def generate_recordings(
template_locs = np.array(locs)[reordered_idx_cells]
template_rots = np.array(rots)[reordered_idx_cells]
template_bin = np.array(bin_cat)[reordered_idx_cells]
templates = np.array(eaps)[reordered_idx_cells]
templates = np.empty((len(reordered_idx_cells), *eaps.shape[1:]), dtype=eaps.dtype)
for i, reordered_idx in enumerate(reordered_idx_cells):
templates[i] = eaps[reordered_idx]
self.template_ids = reordered_idx_cells
else:
print(f"Using provided template ids: {self.template_ids}")
Expand Down Expand Up @@ -991,7 +993,7 @@ def generate_recordings(

if verbose_1:
print("Smoothing templates")
templates = templates * window
templates *= window

# delete temporary preprocessed templates
del templates_rs, templates_pad
Expand Down
6 changes: 3 additions & 3 deletions src/MEArec/simulate_cells.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,7 +875,7 @@ def calc_extracellular(
if verbose >= 1:
print(f"Done generating EAPs for {cell_name}")

saved_eaps = np.array(saved_eaps)
saved_eaps = np.array(saved_eaps, dtype=np.float32)
saved_positions = np.array(saved_positions)
saved_rotations = np.array(saved_rotations)

Expand Down Expand Up @@ -1268,14 +1268,14 @@ def check_solidangle(matrix, pre, post, polarlim):
cell.set_rotation(x=x_rot, y=y_rot, z=z_rot)
rot = [x_rot, y_rot, z_rot]

lfp = electrodes.get_transformation_matrix() @ cell.imem
lfp = np.array(electrodes.get_transformation_matrix() @ cell.imem, dtype=np.float32)

# Reverse rotation to bring cell back into initial rotation state
if rotation is not None:
rev_rot = [-r for r in rot]
cell.set_rotation(rev_rot[0], rev_rot[1], rev_rot[2], rotation_order="zyx")

return 1000 * lfp, pos, rot, found_position
return 1e3 * lfp, pos, rot, found_position


def str2bool(v):
Expand Down
25 changes: 18 additions & 7 deletions src/MEArec/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def load_tmp_eap(templates_folder, celltypes=None, samples_per_cat=None, verbose
print("loading cell type: ", f)
if celltypes is not None:
if celltype in celltypes:
eaps = np.load(str(eaplist[idx]))
eaps = np.load(str(eaplist[idx]), mmap_mode="r")
locs = np.load(str(loclist[idx]))
rots = np.load(str(rotlist[idx]))

Expand All @@ -230,7 +230,7 @@ def load_tmp_eap(templates_folder, celltypes=None, samples_per_cat=None, verbose
else:
ignored_categories.add(celltype)
else:
eaps = np.load(str(eaplist[idx]))
eaps = np.load(str(eaplist[idx]), mmap_mode="r")
locs = np.load(str(loclist[idx]))
rots = np.load(str(rotlist[idx]))

Expand All @@ -245,10 +245,17 @@ def load_tmp_eap(templates_folder, celltypes=None, samples_per_cat=None, verbose
cat_list.extend([celltype] * samples_to_read)
loaded_categories.add(celltype)

if len(eap_list) > 0:
all_eaps = np.lib.format.open_memmap(templates_folder / "all_eaps.npy", mode="w+", dtype=eaps[0].dtype, shape=(len(eap_list), *eap_list[0].shape))
for i in range(len(eap_list)):
all_eaps[i, ...] = eap_list[i]
else:
all_eaps = np.array([])

if verbose:
print("Done loading spike data ...")

return np.array(eap_list), np.array(loc_list), np.array(rot_list), np.array(cat_list, dtype=str)
return all_eaps, np.array(loc_list), np.array(rot_list), np.array(cat_list, dtype=str)


def load_templates(templates, return_h5_objects=True, verbose=False, check_suffix=True):
Expand Down Expand Up @@ -553,7 +560,7 @@ def save_template_generator(tempgen, filename=None, verbose=True):
print("\nSaved templates in", filename, "\n")


def save_recording_generator(recgen, filename=None, verbose=False):
def save_recording_generator(recgen, filename=None, verbose=False, include_spike_traces: bool = True):
"""
Save recordings to disk.
Expand All @@ -565,6 +572,8 @@ def save_recording_generator(recgen, filename=None, verbose=False):
Path to .h5 file
verbose : bool
If True output is verbose
include_spike_traces: bool, default=True
If True, will include the spike traces (which can be large for many units)
"""
filename = Path(filename)
if not filename.parent.is_dir():
Expand All @@ -573,12 +582,12 @@ def save_recording_generator(recgen, filename=None, verbose=False):
with h5py.File(filename, "w") as f:
f.attrs["mearec_version"] = mearec_version
f.attrs["date"] = datetime.now().strftime("%y-%m-%d %H:%M:%S")
save_recording_to_file(recgen, f)
save_recording_to_file(recgen, f, include_spike_traces=include_spike_traces)
if verbose:
print("\nSaved recordings in", filename, "\n")


def save_recording_to_file(recgen, f, path=""):
def save_recording_to_file(recgen, f, path="", include_spike_traces: bool = True):
"""
Save recordings to file handler.
Expand All @@ -588,6 +597,8 @@ def save_recording_to_file(recgen, f, path=""):
RecordingGenerator object to be saved
filename : _io.TextIOWrapper
File handler
include_spike_traces: bool, default=True
If True, will include the spike traces (can be heavy)
"""
save_dict_to_hdf5(recgen.info, f, path + "info/")
if len(recgen.voltage_peaks) > 0:
Expand All @@ -598,7 +609,7 @@ def save_recording_to_file(recgen, f, path=""):
f.create_dataset(path + "recordings", data=recgen.recordings)
if recgen.gain_to_uV is not None:
f["recordings"].attrs["gain_to_uV"] = recgen.gain_to_uV
if len(recgen.spike_traces) > 0:
if len(recgen.spike_traces) > 0 and include_spike_traces:
f.create_dataset(path + "spike_traces", data=recgen.spike_traces)
if len(recgen.spiketrains) > 0:
for ii in range(len(recgen.spiketrains)):
Expand Down

0 comments on commit 57483d7

Please sign in to comment.