Skip to content

Commit

Permalink
allow for mode in plot_response
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Nov 27, 2023
1 parent 1a47ba0 commit 6c3b4eb
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/spikeanalysis/spike_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,6 +811,7 @@ def plot_response_trace(
ebar: bool = False,
color="black",
show_stim: bool = True,
mode: Literal['mean', 'median', 'max', 'min'] = 'mean'
):
"""
Function for plotting response traces for either z scored or raw firing rates
Expand All @@ -829,6 +830,8 @@ def plot_response_trace(
Color to plot the traces in
show_stim: bool, default=True
Whether to show stimulus lines
mode: 'mean'| 'median' | 'max' | 'min', deafult: 'mean'
How to calculate values for plotting
"""

Expand All @@ -843,6 +846,17 @@ def plot_response_trace(

stim_lengths = self._get_event_lengths()

assert mode in ('mean', 'median', 'max', 'min'), f"mode must be 'mean' 'median', 'max', 'min you entered {mode}"

if mode=='mean':
func = np.nanmean
elif mode == 'median':
func = np.nanmedian
elif mode == 'max':
func = np.nanmax
else:
func = np.nanmin

for stimulus, response in data.items():
current_length = stim_lengths[stimulus]
current_bins = bins[stimulus]
Expand All @@ -863,7 +877,7 @@ def plot_response_trace(
)
elif by_neuron:
for neuron in range(np.shape(response)[0]):
avg_response = np.nanmean(response[neuron], axis=0)
avg_response = func(response[neuron], axis=0)
ebars = np.nanstd(response[neuron], axis=0)
if ebar:
self._plot_one_trace(
Expand All @@ -887,7 +901,7 @@ def plot_response_trace(
)
elif by_trial:
for trial in range(np.shape(response)[1]):
avg_response = np.nanmean(response[:, trial, :], axis=0)
avg_response = func(response[:, trial, :], axis=0)
ebars = np.nanstd(response[:, trial, :], axis=0)
if ebar:
self._plot_one_trace(
Expand All @@ -910,8 +924,8 @@ def plot_response_trace(
stim_lines=current_length,
)
else:
avg_response = np.nanmean(np.nanmean(response, axis=1), axis=0)
ebars = np.nanstd(np.nanmean(response, axis=1), axis=0)
avg_response = np.mean(func(response, axis=1), axis=0)
ebars = np.nanstd(func(response, axis=1), axis=0)
if ebar:
self._plot_one_trace(
current_bins,
Expand Down

0 comments on commit 6c3b4eb

Please sign in to comment.