diff --git a/src/spikeanalysis/plotbase.py b/src/spikeanalysis/plotbase.py index 51f0d78..e0cf25f 100644 --- a/src/spikeanalysis/plotbase.py +++ b/src/spikeanalysis/plotbase.py @@ -19,6 +19,9 @@ "fontname": "The font to use", "fontstyle": "The style to use for the font", "fontsize": "The size of the text", + "save": "Whether to save images", + "format": "The format to save the image", + "extra_title": "Additional info to add to image title" } @@ -85,6 +88,10 @@ def _convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple: x_axis = plot_kwargs.pop("x_axis", self.x_axis) y_axis = plot_kwargs.pop("y_axis", self.y_axis) + save = plot_kwargs.pop("save", False) + format = plot_kwargs.pop("format", "png") + extra_title = plot_kwargs.pop("extra_title", '') + PlotKwargs = namedtuple( "PlotKwargs", [ @@ -99,10 +106,13 @@ def _convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple: "fontname", "fontstyle", "fontsize", + "save", + "format", + "extra_title" ], ) - plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis, fontname, fontstyle, fontsize) + plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis, fontname, fontstyle, fontsize, save, format, extra_title) return plot_kwargs @@ -116,3 +126,8 @@ def set_plot_kwargs(self, ax: plt.axes, plot_kwargs: namedtuple): if plot_kwargs.ylim is not None: ax.set_ylim(plot_kwargs.ylim) + + def _save_fig(self, cluster_number, extra_title='', format='png'): + + title = f"{cluster_number}_{extra_title}" + plt.savefig(title, format=format) \ No newline at end of file diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 74ae146..1c1554c 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -435,6 +435,10 @@ def _plot_scores( fontname=plot_kwargs.fontname, ) plt.figure(dpi=plot_kwargs.dpi) + if plot_kwargs.save and plot_kwargs.title is not None: + self._save_fig(plot_kwargs.title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) + elif plot_kwargs.title is None: + print('give title to save heat map') plt.show() if reset_index: @@ -590,6 +594,8 @@ def plot_raster( fontname=plot_kwargs.fontname, ) plt.figure(dpi=plot_kwargs.dpi) + if plot_kwargs.save: + self._save_fig(title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) plt.show() def plot_sm_fr( @@ -770,6 +776,9 @@ def plot_sm_fr( fontname=plot_kwargs.fontname, ) plt.figure(dpi=plot_kwargs.dpi) + + if plot_kwargs.save: + self._save_fig(title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) plt.show() def plot_zscores_ind(self, z_bar: Optional[list[int]] = None, show_stim: bool = True):