diff --git a/src/spikeanalysis/plotbase.py b/src/spikeanalysis/plotbase.py index 2a383db..aecd3aa 100644 --- a/src/spikeanalysis/plotbase.py +++ b/src/spikeanalysis/plotbase.py @@ -60,6 +60,9 @@ def convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple: dpi = plot_kwargs.pop("dpi", self.dpi) x_lim = plot_kwargs.pop("xlim", None) y_lim = plot_kwargs.pop("ylim", None) + fontname = plot_kwargs.pop("fontname", None) + fontstyle = plot_kwargs.pop("fontstyle", "normal") + fontsize = plot_kwargs.fontsize("fontsize", None) title = plot_kwargs.pop("title", self.title) cmap = plot_kwargs.pop("cmap", self.cmap) @@ -67,9 +70,24 @@ def convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple: x_axis = plot_kwargs.pop("x_axis", self.x_axis) y_axis = plot_kwargs.pop("y_axis", self.y_axis) - PlotKwargs = namedtuple("PlotKwargs", ["figsize", "dpi", "x_lim", "y_lim", "title", "cmap", "x_axis", "y_axis"]) - - plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis) + PlotKwargs = namedtuple( + "PlotKwargs", + [ + "figsize", + "dpi", + "x_lim", + "y_lim", + "title", + "cmap", + "x_axis", + "y_axis", + "fontname", + "fontstyle", + "fontsize", + ], + ) + + plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis, fontname, fontstyle, fontsize) return plot_kwargs diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index def6fba..05f3606 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -317,7 +317,12 @@ def _plot_scores( for idx, sub_ax in enumerate(axes.flat): im = sub_ax.imshow(sorted_z_scores[:, idx, :], vmin=vmin, vmax=vmax, cmap=cmap, aspect="auto") - sub_ax.set_xlabel(self.x_axis, fontsize="small") + sub_ax.set_xlabel( + self.x_axis, + fontsize=plot_kwargs.fontsize, + fontname=plot_kwargs.fontname, + fontstyle=plot_kwargs.fontstyle, + ) sub_ax.set_xticks([i * bins_length for i in range(7)]) sub_ax.set_xticklabels([round(bins[i * bins_length], 4) if i < 7 else z_window[1] for i in range(7)]) if idx == 0: @@ -358,9 +363,19 @@ def _plot_scores( cbar_label = "Raw Firing" plt.colorbar(im, cax=cax, label=cbar_label) # Similar to fig.colorbar(im, cax = cax) if plot_kwargs.title is None: - plt.title(f"{stimulus}") + plt.title( + f"{stimulus}", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) else: - plt.title(plot_kwargs.title) + plt.title( + plot_kwargs.title, + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) plt.figure(dpi=plot_kwargs.dpi) plt.show() @@ -488,16 +503,32 @@ def plot_raster( ax.plot([0, 0], [0, np.nanmax(raster_y) + 1], color="red", linestyle=":") ax.plot([events, events], [0, np.nanmax(raster_y) + 1], color="red", linestyle=":") - ax.set(xlabel=plot_kwargs.x_axis, ylabel=ylabel) + ax.set( + xlabel=plot_kwargs.x_axis, + ylabel=ylabel, + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) self.set_plot_kwargs(ax, plot_kwargs) plt.grid(False) plt.tight_layout() self._despine(ax) if plot_kwargs.title is None: - plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8) + plt.title( + f"{stimulus}: {self.data.cluster_ids[idy]}", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) else: - plt.title(plot_kwargs.title) + plt.title( + plot_kwargs.title, + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) plt.figure(dpi=plot_kwargs.dpi) plt.show() @@ -647,16 +678,36 @@ def plot_sm_fr( ax.set(ylim=(0, np.max(mean_smoothed_psth) + np.max(stderr) + 1)) self.set_plot_kwargs(ax, plot_kwargs) - ax.set_ylabel(ylabel) - ax.set_xlabel(plot_kwargs.x_axis) + ax.set_ylabel( + ylabel, + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) + ax.set_xlabel( + plot_kwargs.x_axis, + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) plt.tight_layout() self._despine(ax) if plot_kwargs.title is not None: - plt.title(plot_kwargs.title) + plt.title( + plot_kwargs.title, + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) else: - plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8) + plt.title( + f"{stimulus}: {self.data.cluster_ids[cluster_number]}", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) plt.figure(dpi=plot_kwargs.dpi) plt.show() @@ -810,10 +861,25 @@ def plot_latencies(self, colors="red", plot_kwargs={}): fig, ax = plt.subplots(figsize=plot_kwargs.figsize) ax.hist(lat_by_neuron, density=True, bins=bins, color=color, alpha=0.8) ax.hist(shufl_bsl_neuron, density=True, bins=bins, color="k", alpha=0.8) - ax.set_xlabel("Time (ms)", fontsize="small") - ax.set_ylabel("Counts", fontsize="small") + ax.set_xlabel( + "Time (ms)", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) + ax.set_ylabel( + "Counts", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) self.set_plot_kwargs(ax, plot_kwargs) - plt.title(f"{stimulus.title()}: {self.data.cluster_ids[neuron]}") + plt.title( + f"{stimulus.title()}: {self.data.cluster_ids[neuron]}", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) self._despine(ax) plt.tight_layout() plt.figure(dpi=plot_kwargs.dpi) @@ -837,7 +903,9 @@ def plot_isi(self): ax.set_xlabel("Time (ms)") ax.set_ylabel("Counts (Normalized)") self._despine(ax) - plt.title(f"ISI {cluster}") + plt.title( + f"ISI {cluster}", + ) plt.tight_layout() plt.figure(dpi=self.dpi) plt.show() @@ -891,10 +959,25 @@ def plot_event_isi(self, colors: str | dict, include_ids: list | np.array | None ax.stairs(sub_bsl, edges=bins, fill=True, color="k", alpha=0.7) ax.stairs(sub_stim_isi, edges=bins, fill=True, color=color, alpha=0.7) self.set_plot_kwargs(ax, plot_kwargs) - ax.set_xlabel("Time (ms)") - ax.set_ylabel("Counts (Normalized)") + ax.set_xlabel( + "Time (ms)", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) + ax.set_ylabel( + "Counts (Normalized)", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) self._despine(ax) - plt.title(f"isi vs bsl {stimulus}: {self.data.cluster_ids[row]}") + plt.title( + f"isi vs bsl {stimulus}: {self.data.cluster_ids[row]}", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) plt.tight_layout() plt.figure(dpi=plot_kwargs.dpi) plt.show() @@ -1157,10 +1240,22 @@ def _plot_one_trace( ) self.set_plot_kwargs(ax, plot_kwargs) - ax.set_xlabel("Time (s)") - ax.set_ylabel(plot_kwargs.y_axis) + ax.set_xlabel( + "Time (s)", fontsize=plot_kwargs.fontsize, fontstyle=plot_kwargs.fontstyle, fontname=plot_kwargs.fontname + ) + ax.set_ylabel( + plot_kwargs.y_axis, + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) self._despine(ax) - plt.title(f"trace {stim}") + plt.title( + f"trace {stim}", + fontsize=plot_kwargs.fontsize, + fontstyle=plot_kwargs.fontstyle, + fontname=plot_kwargs.fontname, + ) plt.tight_layout() plt.figure(dpi=plot_kwargs.dpi) plt.show()