diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 9269f0e..7800acb 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -811,6 +811,7 @@ def plot_response_trace( ebar: bool = False, color="black", show_stim: bool = True, + mode: Literal['mean', 'median', 'max', 'min'] = 'mean' ): """ Function for plotting response traces for either z scored or raw firing rates @@ -829,6 +830,8 @@ def plot_response_trace( Color to plot the traces in show_stim: bool, default=True Whether to show stimulus lines + mode: 'mean'| 'median' | 'max' | 'min', deafult: 'mean' + How to calculate values for plotting """ @@ -843,6 +846,17 @@ def plot_response_trace( stim_lengths = self._get_event_lengths() + assert mode in ('mean', 'median', 'max', 'min'), f"mode must be 'mean' 'median', 'max', 'min you entered {mode}" + + if mode=='mean': + func = np.nanmean + elif mode == 'median': + func = np.nanmedian + elif mode == 'max': + func = np.nanmax + else: + func = np.nanmin + for stimulus, response in data.items(): current_length = stim_lengths[stimulus] current_bins = bins[stimulus] @@ -863,7 +877,7 @@ def plot_response_trace( ) elif by_neuron: for neuron in range(np.shape(response)[0]): - avg_response = np.nanmean(response[neuron], axis=0) + avg_response = func(response[neuron], axis=0) ebars = np.nanstd(response[neuron], axis=0) if ebar: self._plot_one_trace( @@ -887,7 +901,7 @@ def plot_response_trace( ) elif by_trial: for trial in range(np.shape(response)[1]): - avg_response = np.nanmean(response[:, trial, :], axis=0) + avg_response = func(response[:, trial, :], axis=0) ebars = np.nanstd(response[:, trial, :], axis=0) if ebar: self._plot_one_trace( @@ -910,8 +924,8 @@ def plot_response_trace( stim_lines=current_length, ) else: - avg_response = np.nanmean(np.nanmean(response, axis=1), axis=0) - ebars = np.nanstd(np.nanmean(response, axis=1), axis=0) + avg_response = np.mean(func(response, axis=1), axis=0) + ebars = np.nanstd(func(response, axis=1), axis=0) if ebar: self._plot_one_trace( current_bins,