Merge pull request #45 from zm711/other-stim
Create `TimestampReader` Class
zm711 authored Nov 22, 2023
2 parents 01be7ad + 784e010 commit bbd4175
Showing 4 changed files with 263 additions and 31 deletions.
12 changes: 12 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,12 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/psf/black
rev: 23.11.0
hooks:
- id: black
files: ^src/
47 changes: 41 additions & 6 deletions src/spikeanalysis/spike_plotter.py
@@ -211,7 +211,7 @@ def _plot_scores(
if indices is True, the function will return the cluster ids as displayed in the z bar graph
"""

if data == "zscore":
z_scores = self.data.z_scores
elif data == "raw-data":
@@ -259,7 +259,7 @@ def _plot_scores(

else:
RESET_INDEX = False
assert isinstance(sorting_index, (list,int)), "sorting_index must be list or int"
assert isinstance(sorting_index, (list, int)), "sorting_index must be list or int"
if isinstance(sorting_index, list):
current_sorting_index = sorting_index[stim_idx]
else:
@@ -352,7 +352,13 @@ def _plot_scores(
if indices:
return sorted_cluster_ids

def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
def plot_raster(
self,
window: Union[list, list[list]],
show_stim: bool = True,
include_ids: list | np.ndarray | None = None,
sorted: bool = False,
):
"""
Function to plot rasters
@@ -363,6 +369,8 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
of [start, stop] format
show_stim: bool, default True
Show lines where stim onset and offset are
include_ids: list | np.ndarray | None, default: None
The subset of cluster ids to include in the plots
"""
from .analysis_utils import histogram_functions as hf

@@ -382,6 +390,15 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):

event_lengths = self._get_event_lengths()

cluster_indices = self.data.cluster_ids
if include_ids is not None:
    # map each requested cluster id to its positional index in cluster_ids
    keep_list = []
    for cid in include_ids:
        keep_list.append(np.where(cluster_indices == cid)[0][0])
    keep_list = np.array(keep_list)
else:
    # no filter requested: keep every cluster
    keep_list = np.arange(0, len(cluster_indices), 1)

for idx, stimulus in enumerate(psths.keys()):
bins = psths[stimulus]["bins"]
psth = psths[stimulus]["psth"]
@@ -394,8 +411,10 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
psth = psth[:, :, np.logical_and(bins > sub_window[0], bins < sub_window[1])]
bins = bins[np.logical_and(bins >= sub_window[0], bins <= sub_window[1])]

for idx in range(np.shape(psth)[0]):
psth_sub = np.squeeze(psth[idx])
for idy in range(np.shape(psth)[0]):
if idy not in keep_list:
continue
psth_sub = np.squeeze(psth[idy])

if np.sum(psth_sub) == 0:
continue
@@ -436,7 +455,8 @@ def plot_raster(self, window: Union[list, list[list]], show_stim: bool = True):
sns.despine()
else:
self._despine(ax)
plt.title(f"{self.data.cluster_ids[idx]} stim: {stimulus}", size=7)

plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8)
plt.figure(dpi=self.dpi)
plt.show()

@@ -446,6 +466,8 @@ def plot_sm_fr(
time_bin_ms: Union[float, list[float]],
sm_time_ms: Union[float, list[float]],
show_stim: bool = True,
include_ids: list | np.ndarray | None = None,
sorted: bool = False,
):
"""
Function to plot smoothed firing rates
@@ -462,6 +484,8 @@ def plot_sm_fr(
stimulus
show_stim: bool, default True
Show lines where stim onset and offset are
include_ids: list | np.ndarray | None, default: None
The cluster ids to include for plotting
"""
import matplotlib as mpl
@@ -498,6 +522,15 @@ def plot_sm_fr(
number of bins is {len(time_bin_ms)} and should be {NUM_STIM}"
time_bin_size = np.array(time_bin_ms) / 1000

cluster_indices = self.data.cluster_ids
if include_ids is not None:
    # map each requested cluster id to its positional index in cluster_ids
    keep_list = []
    for cid in include_ids:
        keep_list.append(np.where(cluster_indices == cid)[0][0])
    keep_list = np.array(keep_list)
else:
    # no filter requested: keep every cluster
    keep_list = np.arange(0, len(cluster_indices), 1)

stim_trial_groups = self._get_trial_groups()
event_lengths = self._get_event_lengths_all()
for idx, stimulus in enumerate(psths.keys()):
@@ -529,6 +562,8 @@ def plot_sm_fr(
stderr = np.zeros((len(tg_set), len(bins)))
event_len = np.zeros((len(tg_set)))
for cluster_number in range(np.shape(psth)[0]):
if cluster_number not in keep_list:
continue
smoothed_psth = gaussian_smoothing(psth[cluster_number], bin_size, sm_std)

for trial_number, trial in enumerate(tg_set):
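Usage note (not part of this diff): a hypothetical sketch of the new `include_ids` filter, assuming an already-configured SpikePlotter instance named `plotter`; the ids and window values are placeholders.

import numpy as np

include_ids = np.array([12, 37, 54])  # cluster ids of interest (placeholder values)

# raster restricted to those clusters over a -0.5 s to 2.0 s window
plotter.plot_raster(window=[-0.5, 2.0], show_stim=True, include_ids=include_ids)

# smoothed firing rates for the same subset: 50 ms bins, 100 ms smoothing
plotter.plot_sm_fr(time_bin_ms=50.0, sm_time_ms=100.0, show_stim=True, include_ids=include_ids)
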
174 changes: 155 additions & 19 deletions src/spikeanalysis/stimulus_data.py
@@ -1,22 +1,23 @@
from __future__ import annotations
import json
import os
from .utils import NumpyEncoder
from typing import Optional, Union
from pathlib import Path
import warnings

import neo
import numpy as np

from tqdm import tqdm

from .utils import NumpyEncoder


class StimulusData:
"""Class for preprocessing stimulus data for spike train analysis"""

def __init__(self, file_path: str):
"""Enter the file_path as a string. For Windows prepend with r to prevent spurious escaping.
A Path object can also be given, but make sure it was generated with a raw string"""
from pathlib import Path

import glob
import os

@@ -75,10 +76,10 @@ def get_all_files(self):

def run_all(
self,
stim_index: Optional[int] = None,
stim_length_seconds: Optional[float] = None,
stim_name: Optional[list] = None,
time_slice=(None, None),
stim_index: int | None = None,
stim_length_seconds: float | None = None,
stim_name: list | None = None,
time_slice: tuple = (None, None),
):
"""
Pipeline function to run through all steps necessary to load intan data
@@ -128,7 +129,7 @@ def run_all(

del self.reader # reader and memmap heavy. Delete after this since not needed

def create_neo_reader(self, file_name: Optional[str] = None):
def create_neo_reader(self, file_name: str | Path | None = None):
"""
Function that creates a Neo IntanRawIO reader and then parses the header
@@ -156,6 +157,8 @@ def create_neo_reader(self, file_name: Optional[str] = None):
sample_freq = value[2]
break
self.sample_frequency = sample_freq

self.start_timestamp = reader._raw_data["timestamp"].flatten()[0]
self.reader = reader

def get_analog_data(self, time_slice: tuple = (None, None)):
@@ -200,9 +203,9 @@ def get_analog_data(self, time_slice: tuple = (None, None)):

def digitize_analog_data(
self,
analog_index: Optional[int] = None,
stim_length_seconds: Optional[float] = None,
stim_name: Optional[list[str]] = None,
analog_index: int | None = None,
stim_length_seconds: float | None = None,
stim_name: list[str] | None = None,
):
"""Function to digitize the analog data for stimuli that have "events" rather than assessing
them as continually changing values"""
@@ -410,9 +413,9 @@ def set_stimulus_name(self, stim_names: dict):

def generate_stimulus_trains(
self,
channel_name: Union[str, list[str]],
stim_freq: Union[float, list[float]],
stim_time_secs: Union[float, list[float]],
channel_name: str | list[str],
stim_freq: float | list[float],
stim_time_secs: float | list[float],
):
"""
Function for converting events into event trains, eg for optogenetic stimulus trains
@@ -460,7 +463,7 @@ def save_events(self):
for dig_channel, event_type in digital_events.items():
assert (
"stim" in event_type.keys()
), f"Mst provide name for each stim using the the set_stimulus_name() function. Please do this for {dig_channel}"
), f"Must provide name for each stim using the the set_stimulus_name() function. Please do this for {dig_channel}"
with open("digital_events.json", "w") as write_file:
json.dump(self.digital_events, write_file, cls=NumpyEncoder)
except AttributeError:
@@ -478,8 +481,8 @@ def delete_events(
self,
del_index: int | list[int],
digital: bool = True,
channel_name: Optional[str] = None,
channel_index: Optional[int] = None,
channel_name: str | None = None,
channel_index: int | None = None,
):
"""
Function for deleting a spurious event, eg, an accident trigger event
@@ -517,7 +520,7 @@ def delete_events(
else:
self.dig_analog_events[key] = data_to_clean

def _intan_neo_read_no_dig(self, reader: neo.rawio.IntanRawIO, time_slice=(None, None)) -> np.array:
def _intan_neo_read_no_dig(self, reader: neo.rawio.IntanRawIO, time_slice: tuple = (None, None)) -> np.array:
"""
Utility function that hacks the Neo memmap structure to be able to read
digital events.
@@ -587,3 +590,136 @@ def _calculate_events(self, array: np.array) -> tuple[np.array, np.array]:
lengths = offset - onset

return onset, lengths
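
For reference, a minimal standalone sketch of the rising/falling-edge arithmetic this method and the new class below both rely on; toy data only, and the exact index conventions of the library may differ.

import numpy as np

# toy TTL trace: two events, the second still high when the recording ends
data = np.array([0, 0, 1, 1, 1, 0, 0, 1, 1])

onset = np.where(np.diff(data) > 0)[0]   # index just before each rising edge
offset = np.where(np.diff(data) < 0)[0]  # index just before each falling edge
if data[-1] > 0:
    # close the trailing event at the last sample
    offset = np.pad(offset, (0, 1), "constant", constant_values=len(data) - 1)

lengths = offset - onset
print(onset, offset, lengths)  # [1 6] [4 8] [3 2]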


class TimestampReader:
    """Utility class for loading non-synced, timestamp-based TTL data with leading-edge/falling-edge detection."""

def __init__(
self,
data: list | np.ndarray,
timestamps: list | np.ndarray,
start_timestamp: float = 0.0,
sample_rate: int | None = None,
):
"""
Parameters
----------
data: list | np.ndarray
An array containing the TTL style data of 0s and some int
timestamps: list | np.ndarray
A timestamp for each value given in data
start_timestamp: float, default: 0.0
The starting timestamp to sync the data to a sample time scale
sample_rate: int | None, default: None
The sample rate to convert from time into samples"""

self.data = np.array(data)
self.timestamps = np.array(timestamps)
self._start_timestamp = start_timestamp
self._sample_rate = sample_rate

def set_start_timestamp(self, start_ts: float | StimulusData):
"""
Function to set the timestamp offset
Parameters
----------
start_ts: float | StimulusData
The start timestamp to offset the analysis with"""

if isinstance(start_ts, (float, int)):
self._start_timestamp = float(start_ts)
elif isinstance(start_ts, StimulusData):
self._start_timestamp = start_ts.start_timestamp
else:
raise TypeError(f"`start_ts` must be float or StimulusData. It is of type {type(start_ts)}")

def set_sample_rate(self, sample_rate: int | StimulusData):
"""
Function to set the sample rate
Parameters
----------
sample_rate: int | StimulusData
The sample rate to convert from time to samples"""

if isinstance(sample_rate, (float, int)):
self._sample_rate = sample_rate
elif isinstance(sample_rate, StimulusData):
self._sample_rate = sample_rate.sample_frequency
else:
raise TypeError(f"`start_ts` must be int or StimulusData. It is of type {type(sample_rate)}")

def load_into_stimulus_data(self, stim: StimulusData, new_stim_key: str, in_place: bool = True):
"""Function which loads a timestamp TTL into StimulusData
Parameters
----------
stim: StimulusData
The StimulusData object to use
new_stim_key: str
The key value to use in the `digital_events` dictionary
in_place: bool, default=True
If true loads the new events into the current StimulusData
If false returns a deep copy with the new data loaded
Returns
-------
stim1: StimulusData
If in_place set to false it returns a deepcopy of the StimulusData with
the new events loaded"""

assert isinstance(stim, StimulusData), "function is for loading into StimulusData"
try:
assert (
new_stim_key not in stim.digital_events
), f"`new_stim_key` must be new key current keys are {stim.digital_events.keys()}"
except AttributeError:
warnings.warn(
"This function should be run after all other stimulus data has been processed but before setting trial groups and names"
)
stim.digital_events = {}

onsets, lengths = self._calculate_events()

if in_place:
stim.digital_events[new_stim_key] = {}
stim.digital_events[new_stim_key]["onsets"] = onsets
stim.digital_events[new_stim_key]["lengths"] = lengths
stim.digital_events[new_stim_key]["trial_groups"] = np.ones((len(onsets)))
else:
import copy

stim1 = copy.deepcopy(stim)
stim1.digital_events[new_stim_key] = {}
stim1.digital_events[new_stim_key]["onsets"] = onsets
stim1.digital_events[new_stim_key]["lengths"] = lengths
stim1.digital_events[new_stim_key]["trial_groups"] = np.ones((len(onsets)))
return stim1

def _calculate_events(self) -> tuple[np.ndarray, np.ndarray]:
"""Function to convert from timestamps to samples as well as a leading/falling edge detector
Returns
-------
onset_samples: np.ndarray
The onset of events in samples
lengths: np.ndarray
The lengths of the events in samples"""

assert self._sample_rate, "`sample_rate` must be set to calculate events, use `set_sample_rate()`"

timestamps = self.timestamps - self._start_timestamp
onset = np.where(np.diff(self.data) < 0)[0]
offset = np.where(np.diff(self.data) > 0)[0]
if self.data[0] > 0:
onset = np.pad(onset, (1, 0), "constant", constant_values=0)
if self.data[-1] > 0:
offset = np.pad(offset, (0, 1), "constant", constant_value=self.data[-1])

onset_timestamps = timestamps[onset]
offset_timestamps = timestamps[offset]

onset_samples = onset_timestamps * self._sample_rate
offset_samples = offset_timestamps * self._sample_rate

lengths = offset_samples - onset_samples

return onset_samples, lengths
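
An end-to-end usage sketch of the new class. The recording path, key name, and TTL arrays are placeholders, and the `run_all` arguments needed depend on the recording; this assumes `run_all` has populated `sample_frequency` and `start_timestamp` via `create_neo_reader`.

import numpy as np
from spikeanalysis.stimulus_data import StimulusData, TimestampReader

stim = StimulusData(file_path=r"C:\data\my_recording")  # placeholder path
stim.run_all()  # process the Intan-based events first

# timestamped TTL from a non-synced device: one level value per timestamp
ttl_values = np.array([0, 0, 1, 1, 0, 0, 1, 0])
ttl_times = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5])  # seconds

reader = TimestampReader(data=ttl_values, timestamps=ttl_times)
reader.set_start_timestamp(stim)  # uses stim.start_timestamp
reader.set_sample_rate(stim)      # uses stim.sample_frequency

# add the events under a new key; in_place=False would return a deep copy instead
reader.load_into_stimulus_data(stim, new_stim_key="opto_ttl", in_place=True)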