From 585c2cec016f9d1ad05e8de762823e8b7a2cdf50 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 1 Nov 2023 14:22:55 -0400 Subject: [PATCH 01/10] WIP--allowing loading from any set of timestamps --- src/spikeanalysis/stimulus_data.py | 163 +++++++++++++++++++++++++---- 1 file changed, 145 insertions(+), 18 deletions(-) diff --git a/src/spikeanalysis/stimulus_data.py b/src/spikeanalysis/stimulus_data.py index 2cb994f..ad46c30 100644 --- a/src/spikeanalysis/stimulus_data.py +++ b/src/spikeanalysis/stimulus_data.py @@ -1,14 +1,15 @@ from __future__ import annotations import json import os -from .utils import NumpyEncoder -from typing import Optional, Union +from pathlib import Path +import warnings import neo import numpy as np - from tqdm import tqdm +from .utils import NumpyEncoder + class StimulusData: """Class for preprocessing stimulus data for spike train analysis""" @@ -16,7 +17,7 @@ class StimulusData: def __init__(self, file_path: str): """Enter the file_path as a string. For Windows prepend with r to prevent spurious escaping. A Path object can also be given, but make sure it was generated with a raw string""" - from pathlib import Path + import glob import os @@ -75,10 +76,10 @@ def get_all_files(self): def run_all( self, - stim_index: Optional[int] = None, - stim_length_seconds: Optional[float] = None, - stim_name: Optional[list] = None, - time_slice=(None, None), + stim_index: int | None = None, + stim_length_seconds: float | None = None, + stim_name: list | None = None, + time_slice: tuple =(None, None), ): """ Pipeline function to run through all steps necessary to load intan data @@ -128,7 +129,7 @@ def run_all( del self.reader # reader and memmap heavy. Delete after this since not needed - def create_neo_reader(self, file_name: Optional[str] = None): + def create_neo_reader(self, file_name: str | Path | None = None): """ Function that creates a Neo IntanRawIO reader and then parses the header @@ -156,6 +157,8 @@ def create_neo_reader(self, file_name: Optional[str] = None): sample_freq = value[2] break self.sample_frequency = sample_freq + + self.start_timestamp = reader._raw_data['timestamp'].flatten()[0] self.reader = reader def get_analog_data(self, time_slice: tuple = (None, None)): @@ -200,9 +203,9 @@ def get_analog_data(self, time_slice: tuple = (None, None)): def digitize_analog_data( self, - analog_index: Optional[int] = None, - stim_length_seconds: Optional[float] = None, - stim_name: Optional[list[str]] = None, + analog_index: int | None = None, + stim_length_seconds: float | None = None, + stim_name: list[str] | None = None, ): """Function to digitize the analog data for stimuli that have "events" rather than assessing them as continually changing values""" @@ -410,9 +413,9 @@ def set_stimulus_name(self, stim_names: dict): def generate_stimulus_trains( self, - channel_name: Union[str, list[str]], - stim_freq: Union[float, list[float]], - stim_time_secs: Union[float, list[float]], + channel_name: str | list[str], + stim_freq: float | list[float], + stim_time_secs: float | list[float], ): """ Function for converting events into event trains, eg for optogenetic stimulus trains @@ -478,8 +481,8 @@ def delete_events( self, del_index: int | list[int], digital: bool = True, - channel_name: Optional[str] = None, - channel_index: Optional[int] = None, + channel_name: str | None= None, + channel_index: str | None = None, ): """ Function for deleting a spurious event, eg, an accident trigger event @@ -517,7 +520,7 @@ def delete_events( else: self.dig_analog_events[key] = data_to_clean - def _intan_neo_read_no_dig(self, reader: neo.rawio.IntanRawIO, time_slice=(None, None)) -> np.array: + def _intan_neo_read_no_dig(self, reader: neo.rawio.IntanRawIO, time_slice: tuple =(None, None)) -> np.array: """ Utility function that hacks the Neo memmap structure to be able to read digital events. @@ -587,3 +590,127 @@ def _calculate_events(self, array: np.array) -> tuple[np.array, np.array]: lengths = offset - onset return onset, lengths + + + +class TimestampReader: + + """utility class for helping load non-synced timestamp based data with leading-edge falling-edge.""" + + def __init__(self, data: list | np.ndarray, timestamps: list | np.ndarray, start_timestamp: float = 0.0, sample_rate: int | None = None): + """ + Parameters + ---------- + data: list | np.ndarray + An array containing the TTL style data of 0s and some int + timestamps: list | np.ndarray + A timestamp for each value given in data + start_timestamp: float, default: 0.0 + The starting timestamp to sync the data to a sample time scale + sample_rate int | None, default: None + The sample rate to convert from time into samples""" + + self.data = np.array(data) + self.timestamps = np.array(timestamps) + self._start_timestamp = start_timestamp + self._sample_rate = sample_rate + + def set_start_timestamp(self, start_ts: float | StimulusData): + """ + Function to set the timestamp offset + Parameters + ---------- + start_ts: float | StimulusData + The start timestamp to offset the analysis with""" + + if isinstance(start_ts, (float, int)): + self._start_timestamp = float(start_ts) + elif isinstance(start_ts, StimulusData): + self._start_timestamp = start_ts.start_timestamp + else: + raise TypeError(f'`start_ts` must be float or StimulusData. It is of type {type(start_ts)}') + + def set_sample_rate(self, sample_rate: int | StimulusData): + """ + Function to set the sample rate + Parameters + ---------- + sample_rate: int | StimulusData + The sample rate to convert from time to samples""" + + if isinstance(sample_rate, (float, int)): + self._sample_rate = sample_rate + elif isinstance(sample_rate, StimulusData): + self._sample_rate = sample_rate.sample_frequency + else: + raise TypeError(f'`start_ts` must be int or StimulusData. It is of type {type(sample_rate)}') + + + def load_into_stimulus_data(self, stim: StimulusData, new_stim_key: str, in_place:bool=True): + + """Function which loads a timestamp TTL into StimulusData + Parameters + ---------- + stim: StimulusData + The StimulusData object to use + new_stim_key: str + The key value to use in the `digital_events` dictionary + in_place: bool, default=True + If true loads the new events into the current StimulusData + If false returns a deep copy with the new data loaded + Returns + ------- + stim1: StimulusData + If in_place set to false it returns a deepcopy of the StimulusData with + the new events loaded""" + + assert isinstance(stim, StimulusData), "function is for loading into StimulusData" + try: + assert new_stim_key in stim.digital_events, f'`new_stim_key` must be new key current keys are {stim.digital_events.keys()}' + except AttributeError: + warnings.warn('This function should be run after all other stimulus data has been processed but before setting trial groups and names') + + onsets, lengths = self._calculate_events() + + if in_place: + stim.digital_events[new_stim_key]['onsets'] = onsets + stim.digital_events[new_stim_key]['lengths'] = lengths + stim.digital_events[new_stim_key]['trial_groups'] = np.ones((len(onsets))) + else: + import copy + stim1 = copy.deepcopy(stim) + stim1.digital_events[new_stim_key]['onsets'] = onsets + stim1.digital_events[new_stim_key]['lengths'] = lengths + stim1.digital_events[new_stim_key]['trial_groups'] = np.ones((len(onsets))) + return stim1 + + def _calculate_events(self) -> tuple[np.ndarray, np.ndarray]: + + """Function to convert from timestamps to samples as well as a leading/falling edge detector + Returns + ------- + onset_samples: np.ndarray + The onset of events in samples + lengths: np.ndarray + the lengths of the events in samples""" + + assert self._sample_rate, f"`sample_rate` must be set to calculate events, use `set_sample_rate()`" + + timestamps = self.timestamps - self._start_timestamp + onset = np.where(np.diff(self.data) < 0)[0] + offset = np.where(np.diff(self.data) > 0)[0] + if self.data[0] > 0: + onset = np.pad(onset, (1,0), "constant", constant_values=0) + if self.data[-1] > 0: + offset = np.pad(offset, (0,1), "constant", constant_value = self.data[-1]) + + onset_timestamps = timestamps[onset] + offset_timestamps = timestamps[offset] + + onset_samples = onset_timestamps * self._sample_rate + offset_samples = offset_timestamps * self._sample_rate + + lengths= onset_samples - offset_samples + + return onset_samples, lengths + From c97a4905b957d351f9aa4394320921d3aa8404c0 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 1 Nov 2023 14:24:34 -0400 Subject: [PATCH 02/10] linting --- src/spikeanalysis/stimulus_data.py | 72 ++++++++++++++++-------------- 1 file changed, 39 insertions(+), 33 deletions(-) diff --git a/src/spikeanalysis/stimulus_data.py b/src/spikeanalysis/stimulus_data.py index ad46c30..51d69fe 100644 --- a/src/spikeanalysis/stimulus_data.py +++ b/src/spikeanalysis/stimulus_data.py @@ -17,7 +17,7 @@ class StimulusData: def __init__(self, file_path: str): """Enter the file_path as a string. For Windows prepend with r to prevent spurious escaping. A Path object can also be given, but make sure it was generated with a raw string""" - + import glob import os @@ -79,7 +79,7 @@ def run_all( stim_index: int | None = None, stim_length_seconds: float | None = None, stim_name: list | None = None, - time_slice: tuple =(None, None), + time_slice: tuple = (None, None), ): """ Pipeline function to run through all steps necessary to load intan data @@ -158,7 +158,7 @@ def create_neo_reader(self, file_name: str | Path | None = None): break self.sample_frequency = sample_freq - self.start_timestamp = reader._raw_data['timestamp'].flatten()[0] + self.start_timestamp = reader._raw_data["timestamp"].flatten()[0] self.reader = reader def get_analog_data(self, time_slice: tuple = (None, None)): @@ -481,7 +481,7 @@ def delete_events( self, del_index: int | list[int], digital: bool = True, - channel_name: str | None= None, + channel_name: str | None = None, channel_index: str | None = None, ): """ @@ -520,7 +520,7 @@ def delete_events( else: self.dig_analog_events[key] = data_to_clean - def _intan_neo_read_no_dig(self, reader: neo.rawio.IntanRawIO, time_slice: tuple =(None, None)) -> np.array: + def _intan_neo_read_no_dig(self, reader: neo.rawio.IntanRawIO, time_slice: tuple = (None, None)) -> np.array: """ Utility function that hacks the Neo memmap structure to be able to read digital events. @@ -592,12 +592,17 @@ def _calculate_events(self, array: np.array) -> tuple[np.array, np.array]: return onset, lengths - class TimestampReader: """utility class for helping load non-synced timestamp based data with leading-edge falling-edge.""" - def __init__(self, data: list | np.ndarray, timestamps: list | np.ndarray, start_timestamp: float = 0.0, sample_rate: int | None = None): + def __init__( + self, + data: list | np.ndarray, + timestamps: list | np.ndarray, + start_timestamp: float = 0.0, + sample_rate: int | None = None, + ): """ Parameters ---------- @@ -609,7 +614,7 @@ def __init__(self, data: list | np.ndarray, timestamps: list | np.ndarray, start The starting timestamp to sync the data to a sample time scale sample_rate int | None, default: None The sample rate to convert from time into samples""" - + self.data = np.array(data) self.timestamps = np.array(timestamps) self._start_timestamp = start_timestamp @@ -628,8 +633,8 @@ def set_start_timestamp(self, start_ts: float | StimulusData): elif isinstance(start_ts, StimulusData): self._start_timestamp = start_ts.start_timestamp else: - raise TypeError(f'`start_ts` must be float or StimulusData. It is of type {type(start_ts)}') - + raise TypeError(f"`start_ts` must be float or StimulusData. It is of type {type(start_ts)}") + def set_sample_rate(self, sample_rate: int | StimulusData): """ Function to set the sample rate @@ -643,11 +648,9 @@ def set_sample_rate(self, sample_rate: int | StimulusData): elif isinstance(sample_rate, StimulusData): self._sample_rate = sample_rate.sample_frequency else: - raise TypeError(f'`start_ts` must be int or StimulusData. It is of type {type(sample_rate)}') - - - def load_into_stimulus_data(self, stim: StimulusData, new_stim_key: str, in_place:bool=True): + raise TypeError(f"`start_ts` must be int or StimulusData. It is of type {type(sample_rate)}") + def load_into_stimulus_data(self, stim: StimulusData, new_stim_key: str, in_place: bool = True): """Function which loads a timestamp TTL into StimulusData Parameters ---------- @@ -661,31 +664,35 @@ def load_into_stimulus_data(self, stim: StimulusData, new_stim_key: str, in_plac Returns ------- stim1: StimulusData - If in_place set to false it returns a deepcopy of the StimulusData with + If in_place set to false it returns a deepcopy of the StimulusData with the new events loaded""" - + assert isinstance(stim, StimulusData), "function is for loading into StimulusData" try: - assert new_stim_key in stim.digital_events, f'`new_stim_key` must be new key current keys are {stim.digital_events.keys()}' + assert ( + new_stim_key in stim.digital_events + ), f"`new_stim_key` must be new key current keys are {stim.digital_events.keys()}" except AttributeError: - warnings.warn('This function should be run after all other stimulus data has been processed but before setting trial groups and names') + warnings.warn( + "This function should be run after all other stimulus data has been processed but before setting trial groups and names" + ) onsets, lengths = self._calculate_events() if in_place: - stim.digital_events[new_stim_key]['onsets'] = onsets - stim.digital_events[new_stim_key]['lengths'] = lengths - stim.digital_events[new_stim_key]['trial_groups'] = np.ones((len(onsets))) + stim.digital_events[new_stim_key]["onsets"] = onsets + stim.digital_events[new_stim_key]["lengths"] = lengths + stim.digital_events[new_stim_key]["trial_groups"] = np.ones((len(onsets))) else: import copy + stim1 = copy.deepcopy(stim) - stim1.digital_events[new_stim_key]['onsets'] = onsets - stim1.digital_events[new_stim_key]['lengths'] = lengths - stim1.digital_events[new_stim_key]['trial_groups'] = np.ones((len(onsets))) + stim1.digital_events[new_stim_key]["onsets"] = onsets + stim1.digital_events[new_stim_key]["lengths"] = lengths + stim1.digital_events[new_stim_key]["trial_groups"] = np.ones((len(onsets))) return stim1 - - def _calculate_events(self) -> tuple[np.ndarray, np.ndarray]: + def _calculate_events(self) -> tuple[np.ndarray, np.ndarray]: """Function to convert from timestamps to samples as well as a leading/falling edge detector Returns ------- @@ -695,22 +702,21 @@ def _calculate_events(self) -> tuple[np.ndarray, np.ndarray]: the lengths of the events in samples""" assert self._sample_rate, f"`sample_rate` must be set to calculate events, use `set_sample_rate()`" - + timestamps = self.timestamps - self._start_timestamp onset = np.where(np.diff(self.data) < 0)[0] offset = np.where(np.diff(self.data) > 0)[0] - if self.data[0] > 0: - onset = np.pad(onset, (1,0), "constant", constant_values=0) + if self.data[0] > 0: + onset = np.pad(onset, (1, 0), "constant", constant_values=0) if self.data[-1] > 0: - offset = np.pad(offset, (0,1), "constant", constant_value = self.data[-1]) - + offset = np.pad(offset, (0, 1), "constant", constant_value=self.data[-1]) + onset_timestamps = timestamps[onset] offset_timestamps = timestamps[offset] onset_samples = onset_timestamps * self._sample_rate offset_samples = offset_timestamps * self._sample_rate - lengths= onset_samples - offset_samples + lengths = onset_samples - offset_samples return onset_samples, lengths - From 7f5aebbcdc2e612321db1e2d3a06a28bc64c6cba Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 1 Nov 2023 15:43:08 -0400 Subject: [PATCH 03/10] add test for new class --- test/test_stimulus_data.py | 61 ++++++++++++++++++++++++++++++++++---- 1 file changed, 55 insertions(+), 6 deletions(-) diff --git a/test/test_stimulus_data.py b/test/test_stimulus_data.py index 22185fb..8999efc 100644 --- a/test/test_stimulus_data.py +++ b/test/test_stimulus_data.py @@ -4,6 +4,7 @@ from pathlib import Path from spikeanalysis.stimulus_data import StimulusData +from spikeanalysis.stimulus_data import TimestampReader @pytest.fixture @@ -115,6 +116,10 @@ def test_json_writer(ana_stim, tmp_path): stim1 = ana_stim stim1.digitize_analog_data(stim_length_seconds=0.1) print(stim1.dig_analog_events) + stim1.get_raw_digital_data() + stim1.get_final_digital_data() + stim1.generate_digital_events() + stim1.set_stimulus_name({'DIGITAL-IN-01': 'testdig', 'DIGITAL-IN-02': 'testdig2'}) print(stim1._file_path) stim1._file_path = stim1._file_path / tmp_path print(stim1._file_path) @@ -302,12 +307,17 @@ def test_set_stimulus_name(stim): assert stim.digital_events["DIGITAL-IN-01"]["stim"] == "TEST" -def test_delete_events(stim): - stim.get_raw_digital_data() - stim.get_final_digital_data() - stim.generate_digital_events() - stim.delete_events(del_index=1, channel_name="DIGITAL-IN-01") - assert len(stim.digital_events["DIGITAL-IN-01"]["events"]) == 20 +def test_delete_events(ana_stim): + import copy + ana_stim.digitize_analog_data(stim_length_seconds=0.1, analog_index=0, stim_name=["test"]) + stim1 = copy.deepcopy(ana_stim) + stim1.get_raw_digital_data() + stim1.get_final_digital_data() + stim1.generate_digital_events() + stim1.delete_events(del_index=1, channel_name="DIGITAL-IN-01") + assert len(stim1.digital_events["DIGITAL-IN-01"]["events"]) == 20 + stim1.delete_events(del_index=1, digital=False, channel_index='0') + assert len(stim1.dig_analog_events['0']['events'])==1 def test_run_all(stim): @@ -316,3 +326,42 @@ def test_run_all(stim): assert stim.analog_data.any() assert isinstance(stim.dig_analog_events, dict) assert isinstance(stim.digital_events, dict) + +def test_timestamp_reader(stim): + + data = [0,0,0,1,1,1,0,0,0,1,1,1,0,0,0] + timestamps = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14] + + reader = TimestampReader(data = data, timestamps=timestamps, start_timestamp=1, sample_rate=10) + + assert reader._start_timestamp==1 + assert reader._sample_rate ==10 + assert isinstance(reader.data, np.ndarray), 'should save as ndarray' + + reader.set_sample_rate(stim) + assert reader._sample_rate == 3000. + reader.set_sample_rate(sample_rate = 1) + assert reader._sample_rate == 1, 'sample rate should be reset' + reader.set_start_timestamp(stim) + assert reader._start_timestamp==0 + reader.set_start_timestamp(start_ts = 0) + assert reader._start_timestamp == 0 + + import copy + stim_no_event = copy.deepcopy(stim) + stim.get_raw_digital_data() + stim.get_final_digital_data() + stim.generate_digital_events() + stim1 = copy.deepcopy(stim) + reader.load_into_stimulus_data(stim=stim1, new_stim_key='testdata', in_place=True) + + assert 'testdata' in stim1.digital_events.keys() + + stim2 = copy.deepcopy(stim) + stim3 = reader.load_into_stimulus_data(stim=stim2, new_stim_key='testdata2', in_place=False) + assert 'testdata2' not in stim2.digital_events.keys() + assert 'testdata2' in stim3.digital_events.keys() + + reader.load_into_stimulus_data(stim=stim_no_event, new_stim_key='warning_test', in_place=True) + + \ No newline at end of file From 4954bbc1710017804802f3eb2e2640cef5f3a17f Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 1 Nov 2023 15:43:59 -0400 Subject: [PATCH 04/10] allow for new digitalevents creation --- src/spikeanalysis/stimulus_data.py | 9 ++++-- test/test_stimulus_data.py | 46 +++++++++++++++--------------- 2 files changed, 29 insertions(+), 26 deletions(-) diff --git a/src/spikeanalysis/stimulus_data.py b/src/spikeanalysis/stimulus_data.py index 51d69fe..d2f2524 100644 --- a/src/spikeanalysis/stimulus_data.py +++ b/src/spikeanalysis/stimulus_data.py @@ -463,7 +463,7 @@ def save_events(self): for dig_channel, event_type in digital_events.items(): assert ( "stim" in event_type.keys() - ), f"Mst provide name for each stim using the the set_stimulus_name() function. Please do this for {dig_channel}" + ), f"Must provide name for each stim using the the set_stimulus_name() function. Please do this for {dig_channel}" with open("digital_events.json", "w") as write_file: json.dump(self.digital_events, write_file, cls=NumpyEncoder) except AttributeError: @@ -670,16 +670,18 @@ def load_into_stimulus_data(self, stim: StimulusData, new_stim_key: str, in_plac assert isinstance(stim, StimulusData), "function is for loading into StimulusData" try: assert ( - new_stim_key in stim.digital_events + new_stim_key not in stim.digital_events ), f"`new_stim_key` must be new key current keys are {stim.digital_events.keys()}" except AttributeError: warnings.warn( "This function should be run after all other stimulus data has been processed but before setting trial groups and names" ) + stim.digital_events = {} onsets, lengths = self._calculate_events() if in_place: + stim.digital_events[new_stim_key] = {} stim.digital_events[new_stim_key]["onsets"] = onsets stim.digital_events[new_stim_key]["lengths"] = lengths stim.digital_events[new_stim_key]["trial_groups"] = np.ones((len(onsets))) @@ -687,6 +689,7 @@ def load_into_stimulus_data(self, stim: StimulusData, new_stim_key: str, in_plac import copy stim1 = copy.deepcopy(stim) + stim1.digital_events[new_stim_key] = {} stim1.digital_events[new_stim_key]["onsets"] = onsets stim1.digital_events[new_stim_key]["lengths"] = lengths stim1.digital_events[new_stim_key]["trial_groups"] = np.ones((len(onsets))) @@ -701,7 +704,7 @@ def _calculate_events(self) -> tuple[np.ndarray, np.ndarray]: lengths: np.ndarray the lengths of the events in samples""" - assert self._sample_rate, f"`sample_rate` must be set to calculate events, use `set_sample_rate()`" + assert self._sample_rate, "`sample_rate` must be set to calculate events, use `set_sample_rate()`" timestamps = self.timestamps - self._start_timestamp onset = np.where(np.diff(self.data) < 0)[0] diff --git a/test/test_stimulus_data.py b/test/test_stimulus_data.py index 8999efc..6fbcd30 100644 --- a/test/test_stimulus_data.py +++ b/test/test_stimulus_data.py @@ -119,7 +119,7 @@ def test_json_writer(ana_stim, tmp_path): stim1.get_raw_digital_data() stim1.get_final_digital_data() stim1.generate_digital_events() - stim1.set_stimulus_name({'DIGITAL-IN-01': 'testdig', 'DIGITAL-IN-02': 'testdig2'}) + stim1.set_stimulus_name({"DIGITAL-IN-01": "testdig", "DIGITAL-IN-02": "testdig2"}) print(stim1._file_path) stim1._file_path = stim1._file_path / tmp_path print(stim1._file_path) @@ -309,6 +309,7 @@ def test_set_stimulus_name(stim): def test_delete_events(ana_stim): import copy + ana_stim.digitize_analog_data(stim_length_seconds=0.1, analog_index=0, stim_name=["test"]) stim1 = copy.deepcopy(ana_stim) stim1.get_raw_digital_data() @@ -316,8 +317,8 @@ def test_delete_events(ana_stim): stim1.generate_digital_events() stim1.delete_events(del_index=1, channel_name="DIGITAL-IN-01") assert len(stim1.digital_events["DIGITAL-IN-01"]["events"]) == 20 - stim1.delete_events(del_index=1, digital=False, channel_index='0') - assert len(stim1.dig_analog_events['0']['events'])==1 + stim1.delete_events(del_index=1, digital=False, channel_index="0") + assert len(stim1.dig_analog_events["0"]["events"]) == 1 def test_run_all(stim): @@ -327,41 +328,40 @@ def test_run_all(stim): assert isinstance(stim.dig_analog_events, dict) assert isinstance(stim.digital_events, dict) -def test_timestamp_reader(stim): - data = [0,0,0,1,1,1,0,0,0,1,1,1,0,0,0] - timestamps = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14] +def test_timestamp_reader(stim): + data = [0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0] + timestamps = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] - reader = TimestampReader(data = data, timestamps=timestamps, start_timestamp=1, sample_rate=10) + reader = TimestampReader(data=data, timestamps=timestamps, start_timestamp=1, sample_rate=10) - assert reader._start_timestamp==1 - assert reader._sample_rate ==10 - assert isinstance(reader.data, np.ndarray), 'should save as ndarray' + assert reader._start_timestamp == 1 + assert reader._sample_rate == 10 + assert isinstance(reader.data, np.ndarray), "should save as ndarray" reader.set_sample_rate(stim) - assert reader._sample_rate == 3000. - reader.set_sample_rate(sample_rate = 1) - assert reader._sample_rate == 1, 'sample rate should be reset' + assert reader._sample_rate == 3000.0 + reader.set_sample_rate(sample_rate=1) + assert reader._sample_rate == 1, "sample rate should be reset" reader.set_start_timestamp(stim) - assert reader._start_timestamp==0 - reader.set_start_timestamp(start_ts = 0) + assert reader._start_timestamp == 0 + reader.set_start_timestamp(start_ts=0) assert reader._start_timestamp == 0 import copy + stim_no_event = copy.deepcopy(stim) stim.get_raw_digital_data() stim.get_final_digital_data() stim.generate_digital_events() stim1 = copy.deepcopy(stim) - reader.load_into_stimulus_data(stim=stim1, new_stim_key='testdata', in_place=True) + reader.load_into_stimulus_data(stim=stim1, new_stim_key="testdata", in_place=True) - assert 'testdata' in stim1.digital_events.keys() + assert "testdata" in stim1.digital_events.keys() stim2 = copy.deepcopy(stim) - stim3 = reader.load_into_stimulus_data(stim=stim2, new_stim_key='testdata2', in_place=False) - assert 'testdata2' not in stim2.digital_events.keys() - assert 'testdata2' in stim3.digital_events.keys() - - reader.load_into_stimulus_data(stim=stim_no_event, new_stim_key='warning_test', in_place=True) + stim3 = reader.load_into_stimulus_data(stim=stim2, new_stim_key="testdata2", in_place=False) + assert "testdata2" not in stim2.digital_events.keys() + assert "testdata2" in stim3.digital_events.keys() - \ No newline at end of file + reader.load_into_stimulus_data(stim=stim_no_event, new_stim_key="warning_test", in_place=True) From 8edf467eab76653a49e3b20da7a072a1fb047a7b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Mon, 6 Nov 2023 15:47:30 -0500 Subject: [PATCH 05/10] working on keep list --- src/spikeanalysis/spike_plotter.py | 38 ++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index a36fbbc..6a54eb8 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -181,6 +181,7 @@ def _plot_scores( bar: Optional[list[int]] = None, indices: bool = False, show_stim: bool = True, + ) -> Optional[np.array]: """ Function to plot heatmaps of firing rate data @@ -342,7 +343,7 @@ def _plot_scores( if indices: return sorted_cluster_ids - def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): + def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, include_ids: list | np.nadarry | None = None): """ Function to plot rasters @@ -353,6 +354,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): of [start, stop] format show_stim: bool, default True Show lines where stim onset and offset are + include_ids: list | np.ndarray | None, default: None """ from .analysis_utils import histogram_functions as hf @@ -372,6 +374,16 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): event_lengths = self._get_event_lengths() + + if include_ids is not None: + cluster_indices = self.data.cluster_ids + keep_list = [] + for cid in include_ids: + keep_list.append(np.where(cluster_indices==cid)[0][0]) + keep_list = np.array(keep_list) + else: + keep_list = np.arange(0, len(cluster_indices), 1) + for idx, stimulus in enumerate(psths.keys()): bins = psths[stimulus]["bins"] psth = psths[stimulus]["psth"] @@ -384,8 +396,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])] bins = bins[np.logical_and(bins >= sub_window[0], bins <= sub_window[1])] - for idx in range(np.shape(psth)[0]): - psth_sub = np.squeeze(psth[idx]) + for idy in range(np.shape(psth)[0]): + if idy not in keep_list: + pass + psth_sub = np.squeeze(psth[idy]) if np.sum(psth_sub) == 0: continue @@ -426,7 +440,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True): sns.despine() else: self._despine(ax) - plt.title(f"{self.data.cluster_ids[idx]} stim: {stimulus}", size=7) + plt.title(f"{self.data.cluster_ids[idy]} stim: {stimulus}", size=7) plt.figure(dpi=self.dpi) plt.show() @@ -436,6 +450,7 @@ def plot_sm_fr( time_bin_ms: Union[float, list[float]], sm_time_ms: Union[float, list[float]], show_stim: bool = True, + include_ids: list | np.ndarray | None = None, ): """ Function to plot smoothed firing rates @@ -452,7 +467,9 @@ def plot_sm_fr( stimulus show_stim: bool, default True Show lines where stim onset and offset are - + include_ids: list | np.ndarray | None + The ids to include for plotting + """ import matplotlib as mpl from .analysis_utils import histogram_functions as hf @@ -488,6 +505,15 @@ def plot_sm_fr( number of bins is{len(time_bin_ms)} and should be {NUM_STIM}" time_bin_size = np.array(time_bin_ms) / 1000 + if include_ids is not None: + cluster_indices = self.data.cluster_ids + keep_list = [] + for cid in include_ids: + keep_list.append(np.where(cluster_indices==cid)[0][0]) + keep_list = np.array(keep_list) + else: + keep_list = np.arange(0, len(cluster_indices), 1) + stim_trial_groups = self._get_trial_groups() event_lengths = self._get_event_lengths_all() for idx, stimulus in enumerate(psths.keys()): @@ -519,6 +545,8 @@ def plot_sm_fr( stderr = np.zeros((len(tg_set), len(bins))) event_len = np.zeros((len(tg_set))) for cluster_number in range(np.shape(psth)[0]): + if cluster_number not in keep_list: + pass smoothed_psth = gaussian_smoothing(psth[cluster_number], bin_size, sm_std) for trial_number, trial in enumerate(tg_set): From a84781538bb1ffb44bb3953b0f3dea33fe15a1f3 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 7 Nov 2023 16:03:02 -0500 Subject: [PATCH 06/10] add sorted --- src/spikeanalysis/spike_plotter.py | 35 ++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 6a54eb8..94952eb 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -343,7 +343,7 @@ def _plot_scores( if indices: return sorted_cluster_ids - def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, include_ids: list | np.nadarry | None = None): + def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, include_ids: list | np.nadarry | None = None, sorted: bool = False): """ Function to plot rasters @@ -355,6 +355,8 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i show_stim: bool, default True Show lines where stim onset and offset are include_ids: list | np.ndarray | None, default: None + sub ids to include + sorted: bool, default = False """ from .analysis_utils import histogram_functions as hf @@ -384,6 +386,12 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i else: keep_list = np.arange(0, len(cluster_indices), 1) + if sorted: + sorted_indices = [] + for value_id in include_ids: + sorted_indices.append(np.nonzero(self.cluster_ids==value_id)[0][0]) + sorted_indices = np.array(sorted_indices) + for idx, stimulus in enumerate(psths.keys()): bins = psths[stimulus]["bins"] psth = psths[stimulus]["psth"] @@ -394,6 +402,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i tg_set = np.unique(trial_groups) psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])] + + if sorted: + psth = psth[sorted_indices,...] + bins = bins[np.logical_and(bins >= sub_window[0], bins <= sub_window[1])] for idy in range(np.shape(psth)[0]): @@ -440,7 +452,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i sns.despine() else: self._despine(ax) - plt.title(f"{self.data.cluster_ids[idy]} stim: {stimulus}", size=7) + if sorted: + plt.title(f"{stimulus}: {self.data.cluster_ids[sorted_indices][idy]}", fontsize=8) + else: + plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8) plt.figure(dpi=self.dpi) plt.show() @@ -451,6 +466,7 @@ def plot_sm_fr( sm_time_ms: Union[float, list[float]], show_stim: bool = True, include_ids: list | np.ndarray | None = None, + sorted: bool = False, ): """ Function to plot smoothed firing rates @@ -469,6 +485,7 @@ def plot_sm_fr( Show lines where stim onset and offset are include_ids: list | np.ndarray | None The ids to include for plotting + sorted: bool = False, """ import matplotlib as mpl @@ -514,6 +531,12 @@ def plot_sm_fr( else: keep_list = np.arange(0, len(cluster_indices), 1) + if sorted: + sorted_indices = [] + for value_id in include_ids: + sorted_indices.append(np.nonzero(self.cluster_ids==value_id)[0][0]) + sorted_indices = np.array(sorted_indices) + stim_trial_groups = self._get_trial_groups() event_lengths = self._get_event_lengths_all() for idx, stimulus in enumerate(psths.keys()): @@ -531,6 +554,8 @@ def plot_sm_fr( trial_groups = stim_trial_groups[stimulus] sub_window = windows[idx] psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])] + if sorted: + psth = psth[sorted_indices, ...] bins = bins[np.logical_and(bins > sub_window[0], bins < sub_window[1])] events = event_lengths[stimulus] tg_set = np.unique(trial_groups) @@ -590,8 +615,10 @@ def plot_sm_fr( sns.despine() else: self._despine(ax) - - plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8) + if sorted: + plt.title(f"{stimulus}: {self.data.cluster_ids[sorted_indices][cluster_number]}", fontsize=8) + else: + plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8) plt.figure(dpi=self.dpi) plt.show() From c35b6bc2cd5bf461ca70ee2d4f880b76e24716d7 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 7 Nov 2023 16:09:20 -0500 Subject: [PATCH 07/10] pass -> continue --- src/spikeanalysis/spike_plotter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 94952eb..fee5801 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -389,7 +389,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i if sorted: sorted_indices = [] for value_id in include_ids: - sorted_indices.append(np.nonzero(self.cluster_ids==value_id)[0][0]) + sorted_indices.append(np.nonzero(self.data.cluster_ids==value_id)[0][0]) sorted_indices = np.array(sorted_indices) for idx, stimulus in enumerate(psths.keys()): @@ -410,7 +410,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i for idy in range(np.shape(psth)[0]): if idy not in keep_list: - pass + continue psth_sub = np.squeeze(psth[idy]) if np.sum(psth_sub) == 0: @@ -534,7 +534,7 @@ def plot_sm_fr( if sorted: sorted_indices = [] for value_id in include_ids: - sorted_indices.append(np.nonzero(self.cluster_ids==value_id)[0][0]) + sorted_indices.append(np.nonzero(self.data.cluster_ids==value_id)[0][0]) sorted_indices = np.array(sorted_indices) stim_trial_groups = self._get_trial_groups() @@ -571,7 +571,7 @@ def plot_sm_fr( event_len = np.zeros((len(tg_set))) for cluster_number in range(np.shape(psth)[0]): if cluster_number not in keep_list: - pass + continue smoothed_psth = gaussian_smoothing(psth[cluster_number], bin_size, sm_std) for trial_number, trial in enumerate(tg_set): From 44c1d9eb7a48e674abb5da03457a469fddea6e0b Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Tue, 7 Nov 2023 16:17:31 -0500 Subject: [PATCH 08/10] revert sorted option. not needed --- src/spikeanalysis/spike_plotter.py | 36 ++++++------------------------ 1 file changed, 7 insertions(+), 29 deletions(-) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index fee5801..642c5e3 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -356,7 +356,6 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i Show lines where stim onset and offset are include_ids: list | np.ndarray | None, default: None sub ids to include - sorted: bool, default = False """ from .analysis_utils import histogram_functions as hf @@ -385,13 +384,7 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i keep_list = np.array(keep_list) else: keep_list = np.arange(0, len(cluster_indices), 1) - - if sorted: - sorted_indices = [] - for value_id in include_ids: - sorted_indices.append(np.nonzero(self.data.cluster_ids==value_id)[0][0]) - sorted_indices = np.array(sorted_indices) - + for idx, stimulus in enumerate(psths.keys()): bins = psths[stimulus]["bins"] psth = psths[stimulus]["psth"] @@ -402,10 +395,6 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i tg_set = np.unique(trial_groups) psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])] - - if sorted: - psth = psth[sorted_indices,...] - bins = bins[np.logical_and(bins >= sub_window[0], bins <= sub_window[1])] for idy in range(np.shape(psth)[0]): @@ -452,10 +441,9 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i sns.despine() else: self._despine(ax) - if sorted: - plt.title(f"{stimulus}: {self.data.cluster_ids[sorted_indices][idy]}", fontsize=8) - else: - plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8) + + + plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8) plt.figure(dpi=self.dpi) plt.show() @@ -485,7 +473,6 @@ def plot_sm_fr( Show lines where stim onset and offset are include_ids: list | np.ndarray | None The ids to include for plotting - sorted: bool = False, """ import matplotlib as mpl @@ -531,12 +518,6 @@ def plot_sm_fr( else: keep_list = np.arange(0, len(cluster_indices), 1) - if sorted: - sorted_indices = [] - for value_id in include_ids: - sorted_indices.append(np.nonzero(self.data.cluster_ids==value_id)[0][0]) - sorted_indices = np.array(sorted_indices) - stim_trial_groups = self._get_trial_groups() event_lengths = self._get_event_lengths_all() for idx, stimulus in enumerate(psths.keys()): @@ -554,8 +535,6 @@ def plot_sm_fr( trial_groups = stim_trial_groups[stimulus] sub_window = windows[idx] psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])] - if sorted: - psth = psth[sorted_indices, ...] bins = bins[np.logical_and(bins > sub_window[0], bins < sub_window[1])] events = event_lengths[stimulus] tg_set = np.unique(trial_groups) @@ -615,10 +594,9 @@ def plot_sm_fr( sns.despine() else: self._despine(ax) - if sorted: - plt.title(f"{stimulus}: {self.data.cluster_ids[sorted_indices][cluster_number]}", fontsize=8) - else: - plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8) + + + plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8) plt.figure(dpi=self.dpi) plt.show() From 166e20b6ea99f8f9bf4c17249454d2dcf45fa8a3 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 22 Nov 2023 11:52:08 -0500 Subject: [PATCH 09/10] add pre-commit config to help run black automatically --- .pre-commit-config.yaml | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9cc1129 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/psf/black + rev: 23.11.0 + hooks: + - id: black + files: ^src/ From 784e0100053815da273881d061c4c0d19834eb98 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Wed, 22 Nov 2023 11:54:07 -0500 Subject: [PATCH 10/10] lint file --- src/spikeanalysis/spike_plotter.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index abaf66e..8ebc5e4 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -186,7 +186,6 @@ def _plot_scores( bar: Optional[list[int]] = None, indices: bool = False, show_stim: bool = True, - ) -> Optional[np.array]: """ Function to plot heatmaps of firing rate data @@ -212,7 +211,7 @@ def _plot_scores( if indices is True, the function will return the cluster ids as displayed in the z bar graph """ - + if data == "zscore": z_scores = self.data.z_scores elif data == "raw-data": @@ -260,7 +259,7 @@ def _plot_scores( else: RESET_INDEX = False - assert isinstance(sorting_index, (list,int)), "sorting_index must be list or int" + assert isinstance(sorting_index, (list, int)), "sorting_index must be list or int" if isinstance(sorting_index, list): current_sorting_index = sorting_index[stim_idx] else: @@ -353,7 +352,13 @@ def _plot_scores( if indices: return sorted_cluster_ids - def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, include_ids: list | np.nadarry | None = None, sorted: bool = False): + def plot_raster( + self, + window: Union[list, list[list]], + show_stim: bool = True, + include_ids: list | np.nadarry | None = None, + sorted: bool = False, + ): """ Function to plot rasters @@ -385,16 +390,15 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i event_lengths = self._get_event_lengths() - if include_ids is not None: cluster_indices = self.data.cluster_ids keep_list = [] for cid in include_ids: - keep_list.append(np.where(cluster_indices==cid)[0][0]) + keep_list.append(np.where(cluster_indices == cid)[0][0]) keep_list = np.array(keep_list) else: keep_list = np.arange(0, len(cluster_indices), 1) - + for idx, stimulus in enumerate(psths.keys()): bins = psths[stimulus]["bins"] psth = psths[stimulus]["psth"] @@ -452,7 +456,6 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True, i else: self._despine(ax) - plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8) plt.figure(dpi=self.dpi) plt.show() @@ -483,7 +486,7 @@ def plot_sm_fr( Show lines where stim onset and offset are include_ids: list | np.ndarray | None The ids to include for plotting - + """ import matplotlib as mpl from .analysis_utils import histogram_functions as hf @@ -523,7 +526,7 @@ def plot_sm_fr( cluster_indices = self.data.cluster_ids keep_list = [] for cid in include_ids: - keep_list.append(np.where(cluster_indices==cid)[0][0]) + keep_list.append(np.where(cluster_indices == cid)[0][0]) keep_list = np.array(keep_list) else: keep_list = np.arange(0, len(cluster_indices), 1) @@ -605,7 +608,6 @@ def plot_sm_fr( else: self._despine(ax) - plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8) plt.figure(dpi=self.dpi) plt.show()