Skip to content

Commit

Permalink
make traces prettier
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Oct 4, 2023
1 parent ce38e36 commit 0b25bde
Showing 1 changed file with 35 additions and 12 deletions.
47 changes: 35 additions & 12 deletions src/spikeanalysis/spike_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ def plot_response_trace(
by_trial: bool = False,
ebar: bool = False,
color="black",
show_stim: bool = True
):
"""
Function for plotting response traces for either z scored or raw firing rates
Expand Down Expand Up @@ -787,48 +788,70 @@ def plot_response_trace(
data = self.data.mean_firing_rate
bins = self.data.fr_bins

stim_lengths = self._get_event_lengths()

for stimulus, response in data.items():
current_length = stim_lengths[stimulus]
current_bins = bins[stimulus]
bin_size = current_bins[1]-current_bins[0]
start_pt = np.where((current_bins>-bin_size) & (current_bins< bin_size))[0][0]
end_pt = np.where((current_bins > current_length-bin_size) & (current_bins < current_length + bin_size))[0][0]
stim_lines = [start_pt, end_pt]
if by_trial and by_neuron:
for neuron in range(np.shape(response)[0]):
for trial in range(np.shape(response)[1]):
self._plot_one_trace(
current_bins, response[neuron, trial, :], ebars=None, color=color, stim=stimulus
current_bins, response[neuron, trial, :], ebars=None, color=color, stim=stimulus, show_stim=show_stim, stim_lines=stim_lines
)
elif by_neuron:
for neuron in range(np.shape(response)[0]):
avg_response = np.mean(response[neuron], axis=0)
ebars = np.std(response[neuron], axis=0)
if ebar:
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus)
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
else:
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus)
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
elif by_trial:
for trial in range(np.shape(response)[1]):
avg_response = np.mean(response[:, trial, :], axis=0)
ebars = np.std(response[:, trial, :], axis=0)
if ebar:
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus)
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
else:
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus)
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
else:
avg_response = np.mean(np.mean(response, axis=1), axis=0)
if ebar:
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus)
self._plot_one_trace(current_bins, avg_response, ebars=ebars, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)
else:
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus)
self._plot_one_trace(current_bins, avg_response, ebars=None, color=color, stim=stimulus,show_stim=show_stim, stim_lines=stim_lines)

def _plot_one_trace(self, bins, trace, ebars=None, color="black", stim=""):
def _plot_one_trace(self, bins, trace, ebars=None, color="black", stim="", show_stim: bool= True, stim_lines: list = [0,0]):
"""
Function for plotting one response trace in 2D. I'm going to try
to let it autoscale
"""
fig, ax = plt.subplots(figsize=self.figsize)
ax.plot(bins, trace, color=color)
ax.plot(bins, trace, color=color, linewdith=0.75)
max_pt = np.max(trace)
if ebars is not None:
ax.plot(bins, trace + ebars, color=color)
ax.plot(bins, trace - ebars, color=color)
ax.fill_between(bins, trace - ebars, trace + ebars, color=color, alpha=0.02)
ax.plot(bins, trace + ebars, color=color, linewidth=0.25)
ax.plot(bins, trace - ebars, color=color, linewidth=0.25)
ax.fill_between(bins, trace - ebars, trace + ebars, color=color, alpha=0.4)
max_pt = np.max(trace + ebars)
if show_stim:
ax.axvline(stim_lines[0],
0,
max_pt,
color="black",
linestyle=":",
linewidth=0.5,)
ax.axvline(stim_lines[1],
0,
max_pt,
color="black",
linestyle=":",
linewidth=0.5,)

ax.set_xlabel("Time (s)")
ax.set_ylabel(self.y_axis)
Expand Down

0 comments on commit 0b25bde

Please sign in to comment.