Skip to content

Commit

Permalink
add save function
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Oct 24, 2024
1 parent f0c0f24 commit d7301fa
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
17 changes: 16 additions & 1 deletion src/spikeanalysis/plotbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
"fontname": "The font to use",
"fontstyle": "The style to use for the font",
"fontsize": "The size of the text",
"save": "Whether to save images",
"format": "The format to save the image",
"extra_title": "Additional info to add to image title"
}


Expand Down Expand Up @@ -85,6 +88,10 @@ def _convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple:
x_axis = plot_kwargs.pop("x_axis", self.x_axis)
y_axis = plot_kwargs.pop("y_axis", self.y_axis)

save = plot_kwargs.pop("save", False)
format = plot_kwargs.pop("format", "png")
extra_title = plot_kwargs.pop("extra_title", '')

PlotKwargs = namedtuple(
"PlotKwargs",
[
Expand All @@ -99,10 +106,13 @@ def _convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple:
"fontname",
"fontstyle",
"fontsize",
"save",
"format",
"extra_title"
],
)

plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis, fontname, fontstyle, fontsize)
plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis, fontname, fontstyle, fontsize, save, format, extra_title)

return plot_kwargs

Expand All @@ -116,3 +126,8 @@ def set_plot_kwargs(self, ax: plt.axes, plot_kwargs: namedtuple):

if plot_kwargs.ylim is not None:
ax.set_ylim(plot_kwargs.ylim)

def _save_fig(self, cluster_number, extra_title='', format='png'):

title = f"{cluster_number}_{extra_title}"
plt.savefig(title, format=format)
9 changes: 9 additions & 0 deletions src/spikeanalysis/spike_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,10 @@ def _plot_scores(
fontname=plot_kwargs.fontname,
)
plt.figure(dpi=plot_kwargs.dpi)
if plot_kwargs.save and plot_kwargs.title is not None:
self._save_fig(plot_kwargs.title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format)
elif plot_kwargs.title is None:
print('give title to save heat map')
plt.show()

if reset_index:
Expand Down Expand Up @@ -590,6 +594,8 @@ def plot_raster(
fontname=plot_kwargs.fontname,
)
plt.figure(dpi=plot_kwargs.dpi)
if plot_kwargs.save:
self._save_fig(title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format)
plt.show()

def plot_sm_fr(
Expand Down Expand Up @@ -770,6 +776,9 @@ def plot_sm_fr(
fontname=plot_kwargs.fontname,
)
plt.figure(dpi=plot_kwargs.dpi)

if plot_kwargs.save:
self._save_fig(title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format)
plt.show()

def plot_zscores_ind(self, z_bar: Optional[list[int]] = None, show_stim: bool = True):
Expand Down

0 comments on commit d7301fa

Please sign in to comment.