diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 2f8f837..54db335 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -759,6 +759,7 @@ def plot_response_trace( by_trial: bool = False, ebar: bool = False, color="black", + show_stim: bool = True ): """ Function for plotting response traces for either z scored or raw firing rates @@ -787,48 +788,70 @@ def plot_response_trace( data = self.data.mean_firing_rate bins = self.data.fr_bins + stim_lengths = self._get_event_lengths() + for stimulus, response in data.items(): + current_length = stim_lengths[stimulus] current_bins = bins[stimulus] + bin_size = current_bins[1]-current_bins[0] + start_pt = np.where((current_bins>-bin_size) & (current_bins< bin_size))[0][0] + end_pt = np.where((current_bins > current_length-bin_size) & (current_bins < current_length + bin_size))[0][0] + stim_lines = [start_pt, end_pt] if by_trial and by_neuron: for neuron in range(np.shape(response)[0]): for trial in range(np.shape(response)[1]): self._plot_one_trace( - current_bins, response[neuron, trial, :], ebars=None, color=color, stim=stimulus + current_bins, response[neuron, trial, :], ebars=None, color=color, stim=stimulus, show_stim=show_stim, stim_lines=stim_lines ) elif by_neuron: for neuron in range(np.shape(response)[0]): avg_response = np.mean(response[neuron], axis=0) ebars = np.std(response[neuron], axis=0) if ebar: - self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus) + self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines) else: - self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus) + self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines) elif by_trial: for trial in range(np.shape(response)[1]): avg_response = np.mean(response[:, trial, :], axis=0) ebars = np.std(response[:, trial, :], axis=0) if ebar: - self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus) + self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines) else: - self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus) + self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines) else: avg_response = np.mean(np.mean(response, axis=1), axis=0) if ebar: - self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus) + self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines) else: - self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus) + self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines) - def _plot_one_trace(self, bins, trace, ebars=None, color="black", stim=""): + def _plot_one_trace(self, bins, trace, ebars=None, color="black", stim="", show_stim: bool= True, stim_lines: list = [0,0]): """ Function for plotting one response trace in 2D. I'm going to try to let it autoscale """ fig, ax = plt.subplots(figsize=self.figsize) - ax.plot(bins, trace, color=color) + ax.plot(bins, trace, color=color, linewdith=0.75) + max_pt = np.max(trace) if ebars is not None: - ax.plot(bins, trace + ebars, color=color) - ax.plot(bins, trace - ebars, color=color) - ax.fill_between(bins, trace - ebars, trace + ebars, color=color, alpha=0.02) + ax.plot(bins, trace + ebars, color=color, linewidth=0.25) + ax.plot(bins, trace - ebars, color=color, linewidth=0.25) + ax.fill_between(bins, trace - ebars, trace + ebars, color=color, alpha=0.4) + max_pt = np.max(trace + ebars) + if show_stim: + ax.axvline(stim_lines[0], + 0, + max_pt, + color="black", + linestyle=":", + linewidth=0.5,) + ax.axvline(stim_lines[1], + 0, + max_pt, + color="black", + linestyle=":", + linewidth=0.5,) ax.set_xlabel("Time (s)") ax.set_ylabel(self.y_axis)