Skip to content

Commit

Permalink
Fixed multiple little bugs
Browse files Browse the repository at this point in the history
Fixed a bug in `export_to_phy` where the properties would be linked to
the wrong cluster_id
Fixed a bug in `remove_bad_units` when plotting (removed sparsity)
Allowed the loading of `NumpyFolderSorting`
  • Loading branch information
DradeAW committed Nov 16, 2023
1 parent 8d57941 commit acf0961
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/lussac/core/lussac_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def _load_sortings(sortings_path: dict[str, str]) -> dict[str, si.BaseSorting]:

if not path.exists():
raise FileNotFoundError(f"Could not find the sorting file {path}.")
elif path.is_dir():
elif path.is_dir() and (path / "spike_times.npy").exists():
sorting = se.PhySortingExtractor(path)
else:
sorting = si.load_extractor(path, base_folder=True)
Expand Down
4 changes: 3 additions & 1 deletion src/lussac/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,9 @@ def _save_sortings(self, module_name: str) -> None:

for name, sorting in self.data.sortings.items():
path = f"{self.data.logs_folder}/{module_name}/sorting/{name}.pkl"
# sorting.dump_to_pickle(file_path=path, include_properties=True, relative_to=self.data.logs_folder)
sorting.dump_to_pickle(file_path=path, include_properties=True)
# TODO: Make relative paths work with pickle in SI using hashing.
# TODO: Make relative paths work with pickle in SI.

def _load_sortings(self, module_name: str) -> dict[str, si.BaseSorting]:
"""
Expand All @@ -171,6 +172,7 @@ def _load_sortings(self, module_name: str) -> dict[str, si.BaseSorting]:

logging.info("Loading sortings from previous run...\n")
sortings_path = glob.glob(f"{self.data.logs_folder}/{module_name}/sorting/*.pkl")
# sortings = {pathlib.Path(path).stem: si.load_extractor(path, base_folder=self.data.logs_folder) for path in tqdm(sortings_path)}
sortings = {pathlib.Path(path).stem: si.load_extractor(path) for path in tqdm(sortings_path)}

return sortings
Expand Down
2 changes: 1 addition & 1 deletion src/lussac/modules/export_to_phy.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def run(self, params: dict[str, Any]) -> si.BaseSorting:

for property_name in self.sorting.get_property_keys():
if property_name.startswith('lussac_'):
unit_ids = new_unit_ids['si_unit_id'][np.argmax(new_unit_ids['cluster_id'].values == self.sorting.unit_ids[:, None], axis=1)].values
unit_ids = new_unit_ids['cluster_id'][np.argmax(new_unit_ids['si_unit_id'].values == self.sorting.unit_ids[:, None], axis=1)].values
self.write_tsv_file(output_folder / f"{property_name}.tsv", property_name, unit_ids, self.sorting.get_property(property_name))

if 'estimate_contamination' in params:
Expand Down
2 changes: 1 addition & 1 deletion src/lussac/modules/remove_bad_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@ def _plot_bad_units(self, bad_sorting: si.BaseSorting, reasons_for_removal: list
The sorting object containing the bad units.
"""

wvf_extractor = self.extract_waveforms(sorting=bad_sorting, ms_before=1.5, ms_after=2.5, max_spikes_per_unit=500)
wvf_extractor = self.extract_waveforms(sorting=bad_sorting, ms_before=1.5, ms_after=2.5, max_spikes_per_unit=500, sparse=False)
annotations = [{'text': reason, 'x': 0.6, 'y': 1.02, 'xref': "paper", 'yref': "paper", 'xanchor': "center", 'yanchor': "bottom", 'showarrow': False} for reason in reasons_for_removal]
utils.plot_units(wvf_extractor, filepath=f"{self.logs_folder}/bad_units", annotations_change=annotations)

0 comments on commit acf0961

Please sign in to comment.