Commit
update correlations and docstrings
zm711 committed Nov 30, 2023
1 parent 5a8bb4b commit 292a373
Showing 2 changed files with 27 additions and 13 deletions.
38 changes: 26 additions & 12 deletions src/spikeanalysis/spike_analysis.py
@@ -669,7 +669,11 @@ def compute_event_interspike_intervals(self, time_ms: float = 200):
self.isi_values = raw_data

def trial_correlation(
- self, window: Union[list, list[list]], time_bin_ms: Optional[float] = None, dataset: str = "psth"
+ self,
+ window: Union[list, list[list]],
+ time_bin_ms: Optional[float] = None,
+ dataset: "psth" | "raw" | "z_scores" = "psth",
+ method: "pearson" | "kendall" | "spearman" = "pearson",
):
"""
Function to calculate pairwise correlation coefficients (Pearson by default) of z-scored or raw firing rate data per time bin.
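
The new signature above annotates dataset and method with bare string unions (e.g. "psth" | "raw" | "z_scores"), which only evaluate cleanly at runtime if the module defers annotation evaluation with from __future__ import annotations. A minimal sketch of an equivalent, type-checker-friendly signature using typing.Literal, shown as an illustration rather than the committed code:

from typing import Literal, Optional, Union

class SpikeAnalysisSketch:
    """Stand-in class for illustration only; not part of spikeanalysis."""

    def trial_correlation(
        self,
        window: Union[list, list[list]],
        time_bin_ms: Optional[float] = None,
        dataset: Literal["psth", "raw", "z_scores"] = "psth",
        method: Literal["pearson", "kendall", "spearman"] = "pearson",
    ):
        ...  # the real implementation lives in spike_analysis.py
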
@@ -683,9 +687,10 @@ def trial_correlation(
time_bin_ms : float, optional
Size of time bins to use, given in milliseconds. Bigger time bins smooth the data, which can remove some
artificial differences between trials.
- dataset : str, (psth, z_scores)
- Whether to use the psth (raw spike counts) or z_scored data. The default is 'z_scores'.
+ dataset : "psth" | "raw" | "z_scores", default: "psth"
+ Whether to use the psth (raw spike counts), raw (the firing rates), or z_scored data.
+ method : "pearson" | "kendall" | "spearman", default: "pearson"
+ The correlation method to be used in the pandas.DataFrame.corr() function.
Raises
------
Exception
@@ -698,7 +703,7 @@ def trial_correlation(
"""

if self._save_params:
- parameters = {"trial_correlation": dict(time_bin_ms=time_bin_ms, dataset=dataset)}
+ parameters = {"trial_correlation": dict(time_bin_ms=time_bin_ms, dataset=dataset, method=method)}
jsonify_parameters(parameters, self._file_path)

try:
@@ -714,6 +719,15 @@ def trial_correlation(
except AttributeError:
raise Exception("To run dataset=='psth', ensure 'get_raw_psth' has been run")

+ elif dataset == "raw":
+ try:
+ raw_firing = self.raw_firing_rate
+ data = raw_firing
+ except AttributeError:
+ raise AttributeError(
+ 'To run dataset="raw" ensure "get_raw_psth" and "get_raw_firing_rate" have been run'
+ )
+
elif dataset == "z_scores":
try:
z_scores = self.raw_zscores
@@ -723,7 +737,7 @@ def trial_correlation(
raise Exception("To run dataset=='z_scores', ensure ('get_raw_psth', 'z_score_data')")

else:
- raise Exception(f"You have entered {dataset} and only ('psth', or 'z_scores') are possible options")
+ raise Exception(f"You have entered {dataset} and only ('psth', 'z_scores', or 'raw') are possible options")

windows = verify_window_format(window=window, num_stim=self.NUM_STIM)
if time_bin_ms is not None:
@@ -782,7 +796,7 @@ def trial_correlation(
final_sub_data = np.squeeze(current_data_windowed_by_trial[cluster_number])
data_dataframe = pd.DataFrame(np.squeeze(final_sub_data.T))

- sub_correlations = data_dataframe.corr()
+ sub_correlations = data_dataframe.corr(method=method)
masked_correlations = sub_correlations[sub_correlations != 1]
for row in range(np.shape(masked_correlations)[0]):
final_correlations = np.nanmean(masked_correlations.iloc[row, :])
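
Taken together, these hunks thread the new method argument through to pandas.DataFrame.corr. A hypothetical call, assuming spike_an is a SpikeAnalysis instance on which get_raw_psth (and get_raw_firing_rate, for dataset="raw") has already been run; the window values and bin size below are illustrative only:

spike_an.trial_correlation(
    window=[0, 1],        # per-stimulus window; units follow the rest of the window API (assumption)
    time_bin_ms=50.0,     # larger bins smooth the per-trial counts
    dataset="raw",        # "psth", "raw", or "z_scores"
    method="spearman",    # forwarded to pandas.DataFrame.corr(method=...)
)
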
@@ -838,7 +852,8 @@ def _generate_sample_z_parameter(self) -> dict:

return example_z_parameter

- def save_z_sample_parameters(self, z_parameters: dict):
+ def save_z_parameters(self, z_parameters: dict):
+ """Saves the z parameters to be used in the future."""
import json

with open("z_parameters.json", "w") as write_file:
@@ -881,7 +896,8 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None):
with open("z_parameters.json") as read_file:
z_parameters = json.load(read_file)
else:
- z_parameters = z_parameters
+ if not isinstance(z_parameters, dict):
+ raise TypeError(f"z_parameters must be of type dict, but is of type: {type(z_parameters)}")

if "all" in z_parameters.keys():
SAME_PARAMS = True
@@ -896,7 +912,6 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None):

if SAME_PARAMS:
current_z_params = z_parameters["all"]
-
else:
current_z_params = z_parameters[stim]

@@ -923,7 +938,6 @@ def get_responsive_neurons(self, z_parameters: Optional[dict] = None):
z_above_threshold = np.sum(np.where(current_z_scores_sub < current_score, 1, 0), axis=2)

responsive_neurons = np.where(z_above_threshold > current_n_bins, True, False)
-
self.responsive_neurons[stim][key] = responsive_neurons

def save_responsive_neurons(self):
@@ -934,7 +948,7 @@ def save_responsive_neurons(self):
with open(file_path / "response_profile.json", "w") as write_file:
json.dump(self.responsive_neurons, write_file, cls=NumpyEncoder)

- def _merge_events(self, event_0: dict, event_1: dict):
+ def _merge_events(self, event_0: dict, event_1: dict) -> dict:
"""Utility function for merging digital and analog events into one dictionary"""
events = {**event_0, **event_1}
return events
2 changes: 1 addition & 1 deletion src/spikeanalysis/spike_plotter.py
@@ -19,7 +19,7 @@ class SpikePlotter(PlotterBase):
"""SpikePlotter is a plotting class which allows for plotting of PSTHs, z score heatmaps
in the future it will plot other values"""

- def __init__(self, analysis: Optional[SpikeAnalysis | CuratedSpikeAnalysis] = None, **kwargs):
+ def __init__(self, analysis: Optional[SpikeAnalysis | CuratedSpikeAnalysis | MergedSpikeAnalysis] = None, **kwargs):
"""
SpikePlotter requires a SpikeAnalysis object, which can be set during init
or in the set_analysis function. Not including the SpikeAnalysis object
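
With the widened type hint, a MergedSpikeAnalysis can seed the plotter the same way as the other analysis objects. A hypothetical usage, assuming merged_analysis is a MergedSpikeAnalysis instance:

plotter = SpikePlotter(analysis=merged_analysis)  # set at init
# or, per the docstring, set it later:
plotter = SpikePlotter()
plotter.set_analysis(merged_analysis)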
