diff --git a/src/spikeanalysis/spike_analysis.py b/src/spikeanalysis/spike_analysis.py
index a0a813f..3bce7ed 100644
--- a/src/spikeanalysis/spike_analysis.py
+++ b/src/spikeanalysis/spike_analysis.py
@@ -46,7 +46,7 @@ def __repr__(self):
         final_vars = [current_var for current_var in var if "_" not in current_var[:2]]
         return f"The methods are: {final_methods} \n\n Variables are: {final_vars}"
 
-    def set_spike_data(self, sp: SpikeData):
+    def set_spike_data(self, sp: SpikeData, cluster_ids: np.ndarray | list | None = None):
         """
         loads in spike data from phy for analysis
 
@@ -54,7 +54,8 @@ def set_spike_data(self, sp: SpikeData):
         ----------
         sp : SpikeData
             A SpikeData object to analyze spike trains
-
+        cluster_ids: np.ndarray | list | None, default: None
+            An optional subset of cluster ids; if given, only these clusters are analyzed
         """
 
         if self._file_path is None:
@@ -71,32 +72,39 @@ def set_spike_data(self, sp: SpikeData):
         self.spike_times = sp.raw_spike_times / sp._sampling_rate
         self._cids = sp._cids
-        try:
-            self._qc_threshold = sp._qc_threshold
-            QC_DATA = True
-        except AttributeError:
-            if self._verbose:
-                print(
-                    f"There is no qc run_threshold. Run {_possible_qc} to only\
-                    include acceptable values"
-                )
-            self.qc_threshold = np.array([True for _ in self._cids])
-            QC_DATA = False
-
-        if sp.QC_RUN and QC_DATA:
-            sp.denoise_data()
-        elif QC_DATA:
-            sp.set_qc()
-            sp.denoise_data()
-        else:
+
+        if cluster_ids is None:
             try:
-                sp.denoise_data()
-            except TypeError:
+                self._qc_threshold = sp._qc_threshold
+                QC_DATA = True
+            except AttributeError:
                 if self._verbose:
-                    print("no qc run")
+                    print(
+                        f"There is no qc run_threshold. Run {_possible_qc} to only\
+                        include acceptable values"
+                    )
+                self.qc_threshold = np.array([True for _ in self._cids])
+                QC_DATA = False
+
+            if sp.QC_RUN and QC_DATA:
+                sp.denoise_data()
+            elif QC_DATA:
+                sp.set_qc()
+                sp.denoise_data()
+            else:
+                try:
+                    sp.denoise_data()
+                except TypeError:
+                    if self._verbose:
+                        print("no qc run")
 
         self.raw_spike_times = sp.raw_spike_times
-        self.cluster_ids = sp._cids
+
+        if cluster_ids is None:
+            self.cluster_ids = sp._cids
+        else:
+            self.cluster_ids = np.array(cluster_ids)
+
         self.spike_clusters = sp.spike_clusters
         self._sampling_rate = sp._sampling_rate
diff --git a/src/spikeanalysis/utils.py b/src/spikeanalysis/utils.py
index 7f656a0..fc6d4ce 100644
--- a/src/spikeanalysis/utils.py
+++ b/src/spikeanalysis/utils.py
@@ -3,6 +3,7 @@
 from typing import Union
 import numpy as np
 from pathlib import Path
+from collections import namedtuple
 
 
 class NumpyEncoder(json.JSONEncoder):
@@ -30,6 +31,42 @@ def jsonify_parameters(parameters: dict, file_path: Path | None = None):
         json.dump(new_parameters, write_file)
 
 
+def get_parameters(file_path: str | Path) -> namedtuple:
+    """
+    Function to read in analysis_parameters.json.
+
+    Parameters
+    ----------
+    file_path: str | Path
+        The path to the folder containing the json
+
+    Returns
+    -------
+    function_kwargs: namedtuple
+        A namedtuple of the analysis parameters, one field per analysis step
+    """
+
+    file_path = Path(file_path)
+    assert file_path.is_dir(), "file_path must be the dir containing the analysis_parameters"
+
+    with open(file_path / "analysis_parameters.json", "r") as read_file:
+        parameters = json.load(read_file)
+
+    z_score = parameters.pop("z_score_data", None)
+    raw_firing = parameters.pop("get_raw_firing_rate", None)
+    psth = parameters.pop("get_raw_psth", None)
+    lats = parameters.pop("latencies", None)
+    isi = parameters.pop("compute_event_interspike_interval", None)
+    trial_corr = parameters.pop("trial_correlation", None)
+
+    Functionkwargs = namedtuple(
+        "Functionkwargs", ["psth", "zscore", "latencies", "isi", "trial_correlations", "firing_rate"]
+    )
+    function_kwargs = Functionkwargs(psth, z_score, lats, isi, trial_corr, raw_firing)
+
+    return function_kwargs
+
+
 def verify_window_format(window: Union[list, list[list]], num_stim: int) -> list[list]:
     """Utility function for making sure window format is correct for analysis and plotting functions
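
For context when reviewing, here is a minimal sketch of how the new `cluster_ids` argument would be used. The phy folder path and cluster ids are made up, and the `SpikeData(file_path=...)` / `SpikeAnalysis()` construction is assumed from the rest of the package rather than shown in this diff:

```python
import numpy as np

from spikeanalysis.spike_data import SpikeData
from spikeanalysis.spike_analysis import SpikeAnalysis

# hypothetical phy output folder; substitute a real recording
sp = SpikeData(file_path="path/to/phy_output")

st = SpikeAnalysis()

# previous behavior (unchanged): load all clusters and run the
# qc/denoise branch when qc data is available
st.set_spike_data(sp)

# new behavior in this diff: hand-pick a subset of clusters; the
# qc/denoise branch is skipped entirely when cluster_ids is given
st.set_spike_data(sp, cluster_ids=[2, 7, 11])
assert np.array_equal(st.cluster_ids, np.array([2, 7, 11]))
```

Note that when `cluster_ids` is given the qc filtering is bypassed entirely and the ids are stored without validation, so the caller is responsible for passing ids that actually exist in `sp._cids`.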
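
Likewise, a sketch of reading recorded analysis parameters back with the new `get_parameters` helper. The parameter values are invented for illustration, and it assumes `jsonify_parameters` writes `analysis_parameters.json` into the given folder, since that is the filename `get_parameters` opens:

```python
from pathlib import Path

from spikeanalysis.utils import jsonify_parameters, get_parameters

analysis_dir = Path("path/to/analysis_folder")  # hypothetical folder

# record the kwargs used for a psth computation (illustrative values)
jsonify_parameters({"get_raw_psth": {"time_bin_ms": 1.0, "window": [-10, 20]}}, analysis_dir)

kwargs = get_parameters(analysis_dir)

# each field holds the recorded kwargs dict, or None if that analysis
# step was never recorded in the json
print(kwargs.psth)    # {'time_bin_ms': 1.0, 'window': [-10, 20]}
print(kwargs.zscore)  # None
```

The namedtuple fields rename the json keys (`"get_raw_psth"` becomes `psth`, `"z_score_data"` becomes `zscore`, and so on), and any step missing from the json simply comes back as `None`.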