Skip to content

Commit

Permalink
fix argument type -> fr_type
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Oct 4, 2023
1 parent 0b25bde commit 0b6577e
Showing 1 changed file with 96 additions and 30 deletions.
126 changes: 96 additions & 30 deletions src/spikeanalysis/spike_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def __init__(self, analysis: Optional[SpikeAnalysis | CuratedSpikeAnalysis] = No
self._set_kwargs(**kwargs)

if analysis is not None:
assert isinstance(analysis, (SpikeAnalysis, CuratedSpikeAnalysis)), "analysis must be a SpikeAnalysis dataset"
assert isinstance(
analysis, (SpikeAnalysis, CuratedSpikeAnalysis)
), "analysis must be a SpikeAnalysis dataset"
self.data = analysis

def set_kwargs(self, **kwargs):
Expand Down Expand Up @@ -754,19 +756,19 @@ def plot_event_isi(self):

def plot_response_trace(
self,
type: Literal["zscore", "raw"] = "zscore",
fr_type: Literal["zscore", "raw"] = "zscore",
by_neuron: bool = False,
by_trial: bool = False,
ebar: bool = False,
color="black",
show_stim: bool = True
show_stim: bool = True,
):
"""
Function for plotting response traces for either z scored or raw firing rates
Parameters
----------
type: Literal['zscore', 'raw'], default: 'zscore'
fr_type: Literal['zscore', 'raw'], default: 'zscore'
Whether to generate traces with zscored data or raw firing rate data
by_neuron: bool, default: False
Whether to plot each neuron separate (True) or average over all neurons (False)
Expand All @@ -776,15 +778,17 @@ def plot_response_trace(
Whether to include error bars in the traces
color: matplotlib color, default: 'black'
Color to plot the traces in
show_stim: bool, default=True
Whether to show stimulus lines
"""

assert type in ["zscore", "raw"], "type of data must be zscore or raw"
assert fr_type in ["zscore", "raw"], f"fr_type of data must be zscore or raw, you entered {fr_type}"

if type == "zscore":
if fr_type == "zscore":
data = self.data.z_scores
bins = self.data.z_bins
elif type == "raw":
elif fr_type == "raw":
data = self.data.mean_firing_rate
bins = self.data.fr_bins

Expand All @@ -793,40 +797,98 @@ def plot_response_trace(
for stimulus, response in data.items():
current_length = stim_lengths[stimulus]
current_bins = bins[stimulus]
bin_size = current_bins[1]-current_bins[0]
start_pt = np.where((current_bins>-bin_size) & (current_bins< bin_size))[0][0]
end_pt = np.where((current_bins > current_length-bin_size) & (current_bins < current_length + bin_size))[0][0]
bin_size = current_bins[1] - current_bins[0]
start_pt = np.where((current_bins > -bin_size) & (current_bins < bin_size))[0][0]
end_pt = np.where((current_bins > current_length - bin_size) & (current_bins < current_length + bin_size))[
0
][0]
stim_lines = [start_pt, end_pt]
if by_trial and by_neuron:
for neuron in range(np.shape(response)[0]):
for trial in range(np.shape(response)[1]):
self._plot_one_trace(
current_bins, response[neuron, trial, :], ebars=None, color=color, stim=stimulus, show_stim=show_stim, stim_lines=stim_lines
current_bins,
response[neuron, trial, :],
ebars=None,
color=color,
stim=stimulus,
show_stim=show_stim,
stim_lines=stim_lines,
)
elif by_neuron:
for neuron in range(np.shape(response)[0]):
avg_response = np.mean(response[neuron], axis=0)
ebars = np.std(response[neuron], axis=0)
if ebar:
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
self._plot_one_trace(
current_bins,
avg_response,
ebars=ebars,
color=color,
stim=stimulus,
show_stim=show_stim,
stim_lines=stim_lines,
)
else:
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
self._plot_one_trace(
current_bins,
avg_response,
ebars=None,
color=color,
stim=stimulus,
show_stim=show_stim,
stim_lines=stim_lines,
)
elif by_trial:
for trial in range(np.shape(response)[1]):
avg_response = np.mean(response[:, trial, :], axis=0)
ebars = np.std(response[:, trial, :], axis=0)
if ebar:
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
self._plot_one_trace(
current_bins,
avg_response,
ebars=ebars,
color=color,
stim=stimulus,
show_stim=show_stim,
stim_lines=stim_lines,
)
else:
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
self._plot_one_trace(
current_bins,
avg_response,
ebars=None,
color=color,
stim=stimulus,
show_stim=show_stim,
stim_lines=stim_lines,
)
else:
avg_response = np.mean(np.mean(response, axis=1), axis=0)
if ebar:
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
self._plot_one_trace(
current_bins,
avg_response,
ebars=ebars,
color=color,
stim=stimulus,
show_stim=show_stim,
stim_lines=stim_lines,
)
else:
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
self._plot_one_trace(
current_bins,
avg_response,
ebars=None,
color=color,
stim=stimulus,
show_stim=show_stim,
stim_lines=stim_lines,
)

def _plot_one_trace(self, bins, trace, ebars=None, color="black", stim="", show_stim: bool= True, stim_lines: list = [0,0]):
def _plot_one_trace(
self, bins, trace, ebars=None, color="black", stim="", show_stim: bool = True, stim_lines: list = [0, 0]
):
"""
Function for plotting one response trace in 2D. I'm going to try
to let it autoscale
Expand All @@ -840,18 +902,22 @@ def _plot_one_trace(self, bins, trace, ebars=None, color="black", stim="", show_
ax.fill_between(bins, trace - ebars, trace + ebars, color=color, alpha=0.4)
max_pt = np.max(trace + ebars)
if show_stim:
ax.axvline(stim_lines[0],
0,
max_pt,
color="black",
linestyle=":",
linewidth=0.5,)
ax.axvline(stim_lines[1],
0,
max_pt,
color="black",
linestyle=":",
linewidth=0.5,)
ax.axvline(
stim_lines[0],
0,
max_pt,
color="black",
linestyle=":",
linewidth=0.5,
)
ax.axvline(
stim_lines[1],
0,
max_pt,
color="black",
linestyle=":",
linewidth=0.5,
)

ax.set_xlabel("Time (s)")
ax.set_ylabel(self.y_axis)
Expand Down

0 comments on commit 0b6577e

Please sign in to comment.