diff --git a/src/spikeanalysis/spike_analysis.py b/src/spikeanalysis/spike_analysis.py index 081dde9..73f2a69 100644 --- a/src/spikeanalysis/spike_analysis.py +++ b/src/spikeanalysis/spike_analysis.py @@ -1,4 +1,5 @@ from __future__ import annotations +import json from typing import Union, Optional import numpy as np @@ -8,7 +9,7 @@ from .stimulus_data import StimulusData from .analysis_utils import histogram_functions as hf from .analysis_utils import latency_functions as lf -from .utils import verify_window_format, gaussian_smoothing, NumpyEncoder +from .utils import verify_window_format, gaussian_smoothing, NumpyEncoder, jsonify_parameters _possible_digital = ("generate_digital_events", "set_trial_groups", "set_stimulus_name") @@ -19,16 +20,17 @@ class SpikeAnalysis: """Class for spike train analysis utilizing a SpikeData object and a StimulusData object""" - def __init__(self): + def __init__(self, save_parameters: bool = False): self._file_path = None self.events = {} + self._save_params = save_parameters def __repr__(self): var_methods = dir(self) var = list(vars(self).keys()) # get our currents variables methods = list(set(var_methods) - set(var)) final_methods = [method for method in methods if "__" not in method and method[0] != "_"] - final_vars = [current_var for current_var in var if "_" not in current_var] + final_vars = [current_var for current_var in var if "_" not in current_var[:2]] return f"The methods are: {final_methods} Variables are: {final_vars}" def set_spike_data(self, sp: SpikeData): @@ -162,6 +164,10 @@ def get_raw_psth( """ + if self._save_params: + parameters = {"get_raw_psth": dict(window=window, time_bin_ms=time_bin_ms)} + jsonify_parameters(parameters, self._file_path) + spike_times = self.raw_spike_times spike_clusters = self.spike_clusters cluster_ids = self.cluster_ids @@ -259,6 +265,18 @@ def get_raw_firing_rate( except AttributeError: raise Exception("Run get_raw_psth before running z_score_data") + if self._save_params: + parameters = { + "get_raw_firing_rate": dict( + time_bin_ms=time_bin_ms, + fr_window=fr_window, + mode=mode, + bsl_window=bsl_window, + sm_time_ms=sm_time_ms, + ) + } + jsonify_parameters(parameters, self._file_path) + stim_dict = self._get_key_for_stim() NUM_STIM = self.NUM_STIM @@ -356,6 +374,7 @@ def z_score_data( time_bin_ms: Union[list[float], float], bsl_window: Union[list, list[list]], z_window: Union[list, list[list]], + eps: float = 0, ): """ z scores data the psth data @@ -373,6 +392,8 @@ def z_score_data( The event window for finding the z scores/time_bin. Either a single sequence of (start, end) in relation to stim onset at 0 applied for all stim. Or a list of lists where each stimulus has its own (start, end) + eps: float, default: 0 + Value to prevent nans from occurring during z-scoring Raises ------ @@ -389,6 +410,10 @@ def z_score_data( except AttributeError: raise Exception("Run get_raw_psth before running z_score_data") + if self._save_params: + parameters = {"z_score_data": dict(time_bin_ms=time_bin_ms, bsl_window=bsl_window, z_window=z_window)} + jsonify_parameters(parameters, self._file_path) + stim_dict = self._get_key_for_stim() NUM_STIM = self.NUM_STIM @@ -441,7 +466,8 @@ def z_score_data( for trial_number, trial in enumerate(tqdm(trial_set)): bsl_trial = bsl_psth[:, trials == trial, :] mean_fr = np.mean(np.sum(bsl_trial, axis=2), axis=1) / ((bsl_current[1] - bsl_current[0])) - std_fr = np.std(np.sum(bsl_trial, axis=2), axis=1) / ((bsl_current[1] - bsl_current[0])) + # for future computations may be beneficial to have small eps to std to prevent divide by 0 + std_fr = np.std(np.sum(bsl_trial, axis=2), axis=1) / ((bsl_current[1] - bsl_current[0])) + eps z_trial = z_psth[:, trials == trial, :] / time_bin_current z_trials = hf.z_score_values(z_trial, mean_fr, std_fr) z_scores[stim][:, trials == trial, :] = z_trials[:, :, :] @@ -471,6 +497,10 @@ def latencies(self, bsl_window: Union[list, list[float]], time_bin_ms: float = 5 """ + if self._save_params: + parameters = {"latencies": dict(bsl_window=bsl_window, time_bin_ms=time_bin_ms, num_shuffles=num_shuffles)} + jsonify_parameters(parameters, self._file_path) + NUM_STIM = self.NUM_STIM self._latency_time_bin = time_bin_ms bsl_windows = verify_window_format(window=bsl_window, num_stim=NUM_STIM) @@ -591,6 +621,11 @@ def compute_event_interspike_intervals(self, time_ms: float = 200): None. """ + + if self._save_params: + parameters = {"compute_event_interspike_interval": dict(time_ms=time_ms)} + jsonify_parameters(parameters, self._file_path) + bins = np.linspace(0, time_ms / 1000, num=int(time_ms + 1)) final_isi = {} raw_data = {} @@ -662,7 +697,10 @@ def trial_correlation( """ - assert dataset == "psth", "z-score is wip please only use psth for now" + if self._save_params: + parameters = {"trial_correlation": dict(time_bin_ms=time_bin_ms, dataset=dataset)} + jsonify_parameters(parameters, self._file_path) + try: import pandas as pd except ImportError: @@ -678,7 +716,7 @@ def trial_correlation( elif dataset == "z_scores": try: - z_scores = self.z_scores + z_scores = self.raw_zscores data = z_scores bins = self.z_bins except AttributeError: @@ -800,6 +838,12 @@ def _generate_sample_z_parameter(self) -> dict: return example_z_parameter + def save_z_sample_parameters(self, z_parameters: dict): + import json + + with open("z_parameters.json", "w") as write_file: + json.dump(z_parameters, write_file) + def get_responsive_neurons(self, z_parameters: Optional[dict] = None): """ function for assessing only responsive neurons based on z scored parameters. @@ -833,7 +877,7 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None): or dict of response properties in same format " ) - if len(parameter_file) > 0: + if z_parameters is None: with open("z_parameters.json") as read_file: z_parameters = json.load(read_file) else: diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py index 4b7a00e..7f656a0 100644 --- a/src/spikeanalysis/utils.py +++ b/src/spikeanalysis/utils.py @@ -2,6 +2,7 @@ import json from typing import Union import numpy as np +from pathlib import Path class NumpyEncoder(json.JSONEncoder): @@ -11,9 +12,13 @@ def default(self, obj): return json.JSONEncoder.default(self, obj) -def jsonify_parameters(parameters: dict): +def jsonify_parameters(parameters: dict, file_path: Path | None = None): + if file_path is not None: + assert file_path.exists() + else: + file_path = Path("") try: - with open("analysis_parameters.json", "r") as read_file: + with open(file_path / "analysis_parameters.json", "r") as read_file: old_params = json.load(read_file) old_params.update(parameters) new_parameters = old_params @@ -21,7 +26,7 @@ def jsonify_parameters(parameters: dict): except FileNotFoundError: new_parameters = parameters - with open("analysis_parameters.json", "w") as write_file: + with open(file_path / "analysis_parameters.json", "w") as write_file: json.dump(new_parameters, write_file)