diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..9cc1129
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,12 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v4.5.0
+  hooks:
+    - id: check-yaml
+    - id: end-of-file-fixer
+    - id: trailing-whitespace
+- repo: https://github.com/psf/black
+  rev: 23.11.0
+  hooks:
+    - id: black
+      files: ^src/
diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py
index 8bae8fa..8ebc5e4 100644
--- a/src/spikeanalysis/spike_plotter.py
+++ b/src/spikeanalysis/spike_plotter.py
@@ -211,7 +211,7 @@ def _plot_scores(
         if indices is True, the function will return the cluster ids as displayed
         in the z bar graph
         """
-
+
         if data == "zscore":
             z_scores = self.data.z_scores
         elif data == "raw-data":
@@ -259,7 +259,7 @@ def _plot_scores(
         else:
             RESET_INDEX = False
 
-        assert isinstance(sorting_index, (list,int)), "sorting_index must be list or int"
+        assert isinstance(sorting_index, (list, int)), "sorting_index must be list or int"
         if isinstance(sorting_index, list):
             current_sorting_index = sorting_index[stim_idx]
         else:
             current_sorting_index = sorting_index
@@ -352,7 +352,13 @@ def _plot_scores(
         if indices:
             return sorted_cluster_ids
 
-    def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
+    def plot_raster(
+        self,
+        window: Union[list, list[list]],
+        show_stim: bool = True,
+        include_ids: list | np.ndarray | None = None,
+        sorted: bool = False,
+    ):
         """
         Function to plot rasters
 
@@ -363,6 +369,8 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
             of [start, stop] format
         show_stim: bool, default True
             Show lines where stim onset and offset are
+        include_ids: list | np.ndarray | None, default: None
+            The subset of cluster ids to include in the plot
         """
         from .analysis_utils import histogram_functions as hf
 
@@ -382,6 +390,15 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
 
         event_lengths = self._get_event_lengths()
 
+        if include_ids is not None:
+            cluster_indices = self.data.cluster_ids
+            keep_list = []
+            for cid in include_ids:
+                keep_list.append(np.where(cluster_indices == cid)[0][0])
+            keep_list = np.array(keep_list)
+        else:
+            keep_list = np.arange(0, len(self.data.cluster_ids), 1)
+
         for idx, stimulus in enumerate(psths.keys()):
             bins = psths[stimulus]["bins"]
             psth = psths[stimulus]["psth"]
@@ -394,8 +411,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
                 psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])]
                 bins = bins[np.logical_and(bins >= sub_window[0], bins <= sub_window[1])]
 
-            for idx in range(np.shape(psth)[0]):
-                psth_sub = np.squeeze(psth[idx])
+            for idy in range(np.shape(psth)[0]):
+                if idy not in keep_list:
+                    continue
+                psth_sub = np.squeeze(psth[idy])
 
                 if np.sum(psth_sub) == 0:
                     continue
@@ -436,7 +455,8 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
                     sns.despine()
                 else:
                     self._despine(ax)
-                plt.title(f"{self.data.cluster_ids[idx]} stim: {stimulus}", size=7)
+
+                plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8)
                 plt.figure(dpi=self.dpi)
                 plt.show()
 
@@ -446,6 +466,8 @@ def plot_sm_fr(
         time_bin_ms: Union[float, list[float]],
         sm_time_ms: Union[float, list[float]],
         show_stim: bool = True,
+        include_ids: list | np.ndarray | None = None,
+        sorted: bool = False,
     ):
         """
         Function to plot smoothed firing rates
@@ -462,6 +484,8 @@ def plot_sm_fr(
             stimulus
         show_stim: bool, default True
             Show lines where stim onset and offset are
+        include_ids: list | np.ndarray | None, default: None
+            The cluster ids to include in the plot
         """
         import matplotlib as mpl
 
@@ -498,6 +522,15 @@ def plot_sm_fr(
             number of bins is{len(time_bin_ms)} and should be {NUM_STIM}"
         time_bin_size = np.array(time_bin_ms) / 1000
 
+        if include_ids is not None:
+            cluster_indices = self.data.cluster_ids
+            keep_list = []
+            for cid in include_ids:
+                keep_list.append(np.where(cluster_indices == cid)[0][0])
+            keep_list = np.array(keep_list)
+        else:
+            keep_list = np.arange(0, len(self.data.cluster_ids), 1)
+
         stim_trial_groups = self._get_trial_groups()
         event_lengths = self._get_event_lengths_all()
         for idx, stimulus in enumerate(psths.keys()):
@@ -529,6 +562,8 @@ def plot_sm_fr(
             stderr = np.zeros((len(tg_set), len(bins)))
             event_len = np.zeros((len(tg_set)))
             for cluster_number in range(np.shape(psth)[0]):
+                if cluster_number not in keep_list:
+                    continue
                 smoothed_psth = gaussian_smoothing(psth[cluster_number], bin_size, sm_std)
 
                 for trial_number, trial in enumerate(tg_set):
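The `include_ids` filter added above maps each requested cluster id to its row in the PSTH array and skips every other cluster. A minimal usage sketch (hypothetical: `plotter` stands for an already-constructed plotter from this package with analysis data attached, and the ids and windows are made up):

```python
# Only rasters for the listed cluster ids are drawn; other clusters are skipped.
plotter.plot_raster(window=[-1, 2], include_ids=[12, 57, 103])

# The same filter applies to the smoothed firing-rate plots.
plotter.plot_sm_fr(window=[-1, 2], time_bin_ms=5.0, sm_time_ms=50.0, include_ids=[12, 57])
```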
diff --git a/src/spikeanalysis/stimulus_data.py b/src/spikeanalysis/stimulus_data.py
index 2cb994f..d2f2524 100644
--- a/src/spikeanalysis/stimulus_data.py
+++ b/src/spikeanalysis/stimulus_data.py
@@ -1,14 +1,15 @@
 from __future__ import annotations
 import json
 import os
-from .utils import NumpyEncoder
-from typing import Optional, Union
+from pathlib import Path
+import warnings
 
 import neo
 import numpy as np
-
 from tqdm import tqdm
 
+from .utils import NumpyEncoder
+
 
 class StimulusData:
     """Class for preprocessing stimulus data for spike train analysis"""
@@ -16,7 +17,7 @@ class StimulusData:
 
     def __init__(self, file_path: str):
         """Enter the file_path as a string. For Windows prepend with r to prevent spurious escaping.
         A Path object can also be given, but make sure it was generated with a raw string"""
-        from pathlib import Path
+
         import glob
         import os
@@ -75,10 +76,10 @@ def get_all_files(self):
 
     def run_all(
         self,
-        stim_index: Optional[int] = None,
-        stim_length_seconds: Optional[float] = None,
-        stim_name: Optional[list] = None,
-        time_slice=(None, None),
+        stim_index: int | None = None,
+        stim_length_seconds: float | None = None,
+        stim_name: list | None = None,
+        time_slice: tuple = (None, None),
     ):
         """
         Pipeline function to run through all steps necessary to load intan data
@@ -128,7 +129,7 @@
 
         del self.reader  # reader and memmap heavy. Delete after this since not needed
 
-    def create_neo_reader(self, file_name: Optional[str] = None):
+    def create_neo_reader(self, file_name: str | Path | None = None):
         """
         Function that creates a Neo IntanRawIO reader and then parses the header
 
@@ -156,6 +157,8 @@ def create_neo_reader(self, file_name: Optional[str] = None):
                 sample_freq = value[2]
                 break
         self.sample_frequency = sample_freq
+
+        self.start_timestamp = reader._raw_data["timestamp"].flatten()[0]
         self.reader = reader
 
     def get_analog_data(self, time_slice: tuple = (None, None)):
@@ -200,9 +203,9 @@ def get_analog_data(self, time_slice: tuple = (None, None)):
 
     def digitize_analog_data(
         self,
-        analog_index: Optional[int] = None,
-        stim_length_seconds: Optional[float] = None,
-        stim_name: Optional[list[str]] = None,
+        analog_index: int | None = None,
+        stim_length_seconds: float | None = None,
+        stim_name: list[str] | None = None,
     ):
         """Function to digitize the analog data for stimuli that have "events" rather than
         assessing them as continually changing values"""
@@ -410,9 +413,9 @@ def set_stimulus_name(self, stim_names: dict):
 
     def generate_stimulus_trains(
         self,
-        channel_name: Union[str, list[str]],
-        stim_freq: Union[float, list[float]],
-        stim_time_secs: Union[float, list[float]],
+        channel_name: str | list[str],
+        stim_freq: float | list[float],
+        stim_time_secs: float | list[float],
     ):
         """
         Function for converting events into event trains, eg for optogenetic stimulus trains
@@ -460,7 +463,7 @@ def save_events(self):
             for dig_channel, event_type in digital_events.items():
                 assert (
                     "stim" in event_type.keys()
-                ), f"Mst provide name for each stim using the the set_stimulus_name() function. Please do this for {dig_channel}"
+                ), f"Must provide a name for each stim using the set_stimulus_name() function. Please do this for {dig_channel}"
             with open("digital_events.json", "w") as write_file:
                 json.dump(self.digital_events, write_file, cls=NumpyEncoder)
         except AttributeError:
@@ -478,8 +481,8 @@ def delete_events(
         self,
         del_index: int | list[int],
         digital: bool = True,
-        channel_name: Optional[str] = None,
-        channel_index: Optional[int] = None,
+        channel_name: str | None = None,
+        channel_index: str | None = None,
     ):
         """
         Function for deleting a spurious event, eg, an accident trigger event
@@ -517,7 +520,7 @@ def delete_events(
             else:
                 self.dig_analog_events[key] = data_to_clean
 
-    def _intan_neo_read_no_dig(self, reader: neo.rawio.IntanRawIO, time_slice=(None, None)) -> np.array:
+    def _intan_neo_read_no_dig(self, reader: neo.rawio.IntanRawIO, time_slice: tuple = (None, None)) -> np.array:
         """
         Utility function that hacks the Neo memmap structure to be able to read digital events.
@@ -587,3 +590,136 @@ def _calculate_events(self, array: np.array) -> tuple[np.array, np.array]:
         lengths = offset - onset
 
         return onset, lengths
+
+
+class TimestampReader:
+    """Utility class for loading non-synced, timestamp-based TTL data by detecting leading (rising) and falling edges."""
+
+    def __init__(
+        self,
+        data: list | np.ndarray,
+        timestamps: list | np.ndarray,
+        start_timestamp: float = 0.0,
+        sample_rate: int | None = None,
+    ):
+        """
+        Parameters
+        ----------
+        data: list | np.ndarray
+            An array containing the TTL-style data of 0s and some int
+        timestamps: list | np.ndarray
+            A timestamp for each value given in data
+        start_timestamp: float, default: 0.0
+            The starting timestamp used to sync the data to a sample time scale
+        sample_rate: int | None, default: None
+            The sample rate used to convert from time into samples"""
+
+        self.data = np.array(data)
+        self.timestamps = np.array(timestamps)
+        self._start_timestamp = start_timestamp
+        self._sample_rate = sample_rate
+
+    def set_start_timestamp(self, start_ts: float | StimulusData):
+        """
+        Function to set the timestamp offset
+        Parameters
+        ----------
+        start_ts: float | StimulusData
+            The start timestamp to offset the analysis with"""
+
+        if isinstance(start_ts, (float, int)):
+            self._start_timestamp = float(start_ts)
+        elif isinstance(start_ts, StimulusData):
+            self._start_timestamp = start_ts.start_timestamp
+        else:
+            raise TypeError(f"`start_ts` must be float or StimulusData. It is of type {type(start_ts)}")
+
+    def set_sample_rate(self, sample_rate: int | StimulusData):
+        """
+        Function to set the sample rate
+        Parameters
+        ----------
+        sample_rate: int | StimulusData
+            The sample rate used to convert from time to samples"""
+
+        if isinstance(sample_rate, (float, int)):
+            self._sample_rate = sample_rate
+        elif isinstance(sample_rate, StimulusData):
+            self._sample_rate = sample_rate.sample_frequency
+        else:
+            raise TypeError(f"`sample_rate` must be int or StimulusData. It is of type {type(sample_rate)}")
+
+    def load_into_stimulus_data(self, stim: StimulusData, new_stim_key: str, in_place: bool = True):
+        """Function which loads a timestamp TTL into a StimulusData object
+        Parameters
+        ----------
+        stim: StimulusData
+            The StimulusData object to use
+        new_stim_key: str
+            The key value to use in the `digital_events` dictionary
+        in_place: bool, default: True
+            If True, loads the new events into the current StimulusData;
+            if False, returns a deep copy with the new data loaded
+        Returns
+        -------
+        stim1: StimulusData
+            If in_place is set to False, returns a deep copy of the StimulusData with
+            the new events loaded"""
+
+        assert isinstance(stim, StimulusData), "function is for loading into StimulusData"
+        try:
+            assert (
+                new_stim_key not in stim.digital_events
+            ), f"`new_stim_key` must be a new key; current keys are {stim.digital_events.keys()}"
+        except AttributeError:
+            warnings.warn(
+                "This function should be run after all other stimulus data has been processed, but before setting trial groups and names"
+            )
+            stim.digital_events = {}
+
+        onsets, lengths = self._calculate_events()
+
+        if in_place:
+            stim.digital_events[new_stim_key] = {}
+            stim.digital_events[new_stim_key]["onsets"] = onsets
+            stim.digital_events[new_stim_key]["lengths"] = lengths
+            stim.digital_events[new_stim_key]["trial_groups"] = np.ones((len(onsets)))
+        else:
+            import copy
+
+            stim1 = copy.deepcopy(stim)
+            stim1.digital_events[new_stim_key] = {}
+            stim1.digital_events[new_stim_key]["onsets"] = onsets
+            stim1.digital_events[new_stim_key]["lengths"] = lengths
+            stim1.digital_events[new_stim_key]["trial_groups"] = np.ones((len(onsets)))
+            return stim1
+
+    def _calculate_events(self) -> tuple[np.ndarray, np.ndarray]:
+        """Function to convert from timestamps to samples as well as a leading/falling edge detector
+        Returns
+        -------
+        onset_samples: np.ndarray
+            The onsets of events in samples
+        lengths: np.ndarray
+            The lengths of the events in samples"""
+
+        assert self._sample_rate, "`sample_rate` must be set to calculate events, use `set_sample_rate()`"
+
+        timestamps = self.timestamps - self._start_timestamp
+        # rising (leading) edges mark event onsets; falling edges mark offsets
+        onset = np.where(np.diff(self.data) > 0)[0]
+        offset = np.where(np.diff(self.data) < 0)[0]
+        # if the TTL is already high at the first sample, the event started at index 0
+        if self.data[0] > 0:
+            onset = np.pad(onset, (1, 0), "constant", constant_values=0)
+        # if the TTL is still high at the last sample, close the final event at the last index
+        if self.data[-1] > 0:
+            offset = np.pad(offset, (0, 1), "constant", constant_values=len(self.data) - 1)
+
+        onset_timestamps = timestamps[onset]
+        offset_timestamps = timestamps[offset]
+
+        onset_samples = onset_timestamps * self._sample_rate
+        offset_samples = offset_timestamps * self._sample_rate
+
+        lengths = offset_samples - onset_samples
+
+        return onset_samples, lengths
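Taken together, the new `TimestampReader` syncs an externally clocked TTL trace to an Intan recording and registers it as a digital event. A usage sketch under stated assumptions (the recording path and TTL values are illustrative, not from this patch):

```python
import numpy as np

from spikeanalysis.stimulus_data import StimulusData, TimestampReader

# Hypothetical TTL trace sampled on its own clock: 1 while the stimulus is on.
data = np.array([0, 0, 1, 1, 1, 0, 0, 1, 1, 0])
timestamps = np.arange(10, dtype=float)  # one timestamp per TTL value, in seconds

stim = StimulusData(file_path="path/to/intan_recording")  # illustrative path
stim.create_neo_reader()  # also sets stim.sample_frequency and stim.start_timestamp

reader = TimestampReader(data=data, timestamps=timestamps)
reader.set_start_timestamp(stim)  # align the TTL clock to the recording's first timestamp
reader.set_sample_rate(stim)  # needed to convert event times into samples

# Detects edges and registers onsets/lengths under stim.digital_events["opto_ttl"].
reader.load_into_stimulus_data(stim=stim, new_stim_key="opto_ttl", in_place=True)
```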
diff --git a/test/test_stimulus_data.py b/test/test_stimulus_data.py
index 22185fb..6fbcd30 100644
--- a/test/test_stimulus_data.py
+++ b/test/test_stimulus_data.py
@@ -4,6 +4,7 @@
 from pathlib import Path
 
 from spikeanalysis.stimulus_data import StimulusData
+from spikeanalysis.stimulus_data import TimestampReader
 
 
 @pytest.fixture
@@ -115,6 +116,10 @@ def test_json_writer(ana_stim, tmp_path):
     stim1 = ana_stim
     stim1.digitize_analog_data(stim_length_seconds=0.1)
     print(stim1.dig_analog_events)
+    stim1.get_raw_digital_data()
+    stim1.get_final_digital_data()
+    stim1.generate_digital_events()
+    stim1.set_stimulus_name({"DIGITAL-IN-01": "testdig", "DIGITAL-IN-02": "testdig2"})
     print(stim1._file_path)
     stim1._file_path = stim1._file_path / tmp_path
     print(stim1._file_path)
@@ -302,12 +307,18 @@ def test_set_stimulus_name(stim):
     assert stim.digital_events["DIGITAL-IN-01"]["stim"] == "TEST"
 
 
-def test_delete_events(stim):
-    stim.get_raw_digital_data()
-    stim.get_final_digital_data()
-    stim.generate_digital_events()
-    stim.delete_events(del_index=1, channel_name="DIGITAL-IN-01")
-    assert len(stim.digital_events["DIGITAL-IN-01"]["events"]) == 20
+def test_delete_events(ana_stim):
+    import copy
+
+    ana_stim.digitize_analog_data(stim_length_seconds=0.1, analog_index=0, stim_name=["test"])
+    stim1 = copy.deepcopy(ana_stim)
+    stim1.get_raw_digital_data()
+    stim1.get_final_digital_data()
+    stim1.generate_digital_events()
+    stim1.delete_events(del_index=1, channel_name="DIGITAL-IN-01")
+    assert len(stim1.digital_events["DIGITAL-IN-01"]["events"]) == 20
+    stim1.delete_events(del_index=1, digital=False, channel_index="0")
+    assert len(stim1.dig_analog_events["0"]["events"]) == 1
 
 
 def test_run_all(stim):
@@ -316,3 +327,41 @@ def test_run_all(stim):
     assert stim.analog_data.any()
     assert isinstance(stim.dig_analog_events, dict)
     assert isinstance(stim.digital_events, dict)
+
+
+def test_timestamp_reader(stim):
+    data = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0]
+    timestamps = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]
+
+    reader = TimestampReader(data=data, timestamps=timestamps, start_timestamp=1, sample_rate=10)
+
+    assert reader._start_timestamp == 1
+    assert reader._sample_rate == 10
+    assert isinstance(reader.data, np.ndarray), "should save as ndarray"
+
+    reader.set_sample_rate(stim)
+    assert reader._sample_rate == 3000.0
+    reader.set_sample_rate(sample_rate=1)
+    assert reader._sample_rate == 1, "sample rate should be reset"
+    reader.set_start_timestamp(stim)
+    assert reader._start_timestamp == 0
+    reader.set_start_timestamp(start_ts=0)
+    assert reader._start_timestamp == 0
+
+    import copy
+
+    stim_no_event = copy.deepcopy(stim)
+    stim.get_raw_digital_data()
+    stim.get_final_digital_data()
+    stim.generate_digital_events()
+    stim1 = copy.deepcopy(stim)
+    reader.load_into_stimulus_data(stim=stim1, new_stim_key="testdata", in_place=True)
+
+    assert "testdata" in stim1.digital_events.keys()
+
+    stim2 = copy.deepcopy(stim)
+    stim3 = reader.load_into_stimulus_data(stim=stim2, new_stim_key="testdata2", in_place=False)
+    assert "testdata2" not in stim2.digital_events.keys()
+    assert "testdata2" in stim3.digital_events.keys()
+
+    reader.load_into_stimulus_data(stim=stim_no_event, new_stim_key="warning_test", in_place=True)
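As a quick sanity check on the leading/falling-edge detection inside `_calculate_events`, the TTL vector used in `test_timestamp_reader` works out as follows (a sketch of just the edge detection, assuming `start_timestamp = 0` and `sample_rate = 1` so that timestamps equal sample indices):

```python
import numpy as np

# Two pulses, each three samples long, as in test_timestamp_reader.
data = np.array([0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0])

onsets = np.where(np.diff(data) > 0)[0]   # rising edges  -> array([2, 8])
offsets = np.where(np.diff(data) < 0)[0]  # falling edges -> array([ 5, 11])
print(offsets - onsets)                   # event lengths in samples -> [3 3]
```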