diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 46628b5..67e71db 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -718,13 +718,38 @@ def plot_isi(self): fig, ax = plt.subplots(figsize=self.figsize) ax.hist(isi, density=True, bins=bins, color="k") ax.set_xlabel("Time (ms)") - ax.set_ylabel("Counts") + ax.set_ylabel("Counts (Normalized)") self._despine(ax) plt.title(f"ISI {cluster}") plt.tight_layout() plt.figure(dpi=self.dpi) plt.show() + def plot_event_isi(self): + try: + final_isi = self.data.isi + except AttributeError: + raise Exception("must run `compute_event_interspike_interval()") + + for stimulus, isis in final_isi.items(): + baseline = isis["bsl_isi"].sum(axis=1) + stimulus_isi = isis["isi"].sum(axis=1) + bins = isis["bins"] + for row in range(stimulus.shape[0]): + sub_bsl = baseline[row] / baseline[row].sum() + sub_stim_isi = stimulus_isi[row] / stimulus_isi[row].sum() + + fig, ax = plt.subplots(figsize=self.figsize) + ax.stairs(sub_bsl, edges=bins, fill=True, color="k", alpha=0.7) + ax.stairs(sub_stim_isi, edges=bins, fill=True, color="r", alpha=0.7) + ax.set_xlabel("Time (ms)") + ax.set_ylabel("Counts (Normalized)") + self._despine(ax) + plt.title(f"isi vs bsl {stimulus}: {self.data.cluster_ids[row]}") + plt.tight_layout() + plt.figure(dpi=self.dpi) + plt.show() + def plot_response_trace( self, type: Literal["zscore", "raw"] = "zscore",