Skip to content

Commit

Permalink
update plotting to allow font editing
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Dec 6, 2023
1 parent 7513202 commit 89b4295
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 23 deletions.
24 changes: 21 additions & 3 deletions src/spikeanalysis/plotbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,34 @@ def convert_plot_kwargs(self, plot_kwargs: dict) -> namedtuple:
dpi = plot_kwargs.pop("dpi", self.dpi)
x_lim = plot_kwargs.pop("xlim", None)
y_lim = plot_kwargs.pop("ylim", None)
fontname = plot_kwargs.pop("fontname", None)
fontstyle = plot_kwargs.pop("fontstyle", "normal")
fontsize = plot_kwargs.fontsize("fontsize", None)

title = plot_kwargs.pop("title", self.title)
cmap = plot_kwargs.pop("cmap", self.cmap)

x_axis = plot_kwargs.pop("x_axis", self.x_axis)
y_axis = plot_kwargs.pop("y_axis", self.y_axis)

PlotKwargs = namedtuple("PlotKwargs", ["figsize", "dpi", "x_lim", "y_lim", "title", "cmap", "x_axis", "y_axis"])

plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis)
PlotKwargs = namedtuple(
"PlotKwargs",
[
"figsize",
"dpi",
"x_lim",
"y_lim",
"title",
"cmap",
"x_axis",
"y_axis",
"fontname",
"fontstyle",
"fontsize",
],
)

plot_kwargs = PlotKwargs(figsize, dpi, x_lim, y_lim, title, cmap, x_axis, y_axis, fontname, fontstyle, fontsize)

return plot_kwargs

Expand Down
135 changes: 115 additions & 20 deletions src/spikeanalysis/spike_plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,12 @@ def _plot_scores(

for idx, sub_ax in enumerate(axes.flat):
im = sub_ax.imshow(sorted_z_scores[:, idx, :], vmin=vmin, vmax=vmax, cmap=cmap, aspect="auto")
sub_ax.set_xlabel(self.x_axis, fontsize="small")
sub_ax.set_xlabel(
self.x_axis,
fontsize=plot_kwargs.fontsize,
fontname=plot_kwargs.fontname,
fontstyle=plot_kwargs.fontstyle,
)
sub_ax.set_xticks([i * bins_length for i in range(7)])
sub_ax.set_xticklabels([round(bins[i * bins_length], 4) if i < 7 else z_window[1] for i in range(7)])
if idx == 0:
Expand Down Expand Up @@ -358,9 +363,19 @@ def _plot_scores(
cbar_label = "Raw Firing"
plt.colorbar(im, cax=cax, label=cbar_label) # Similar to fig.colorbar(im, cax = cax)
if plot_kwargs.title is None:
plt.title(f"{stimulus}")
plt.title(
f"{stimulus}",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
else:
plt.title(plot_kwargs.title)
plt.title(
plot_kwargs.title,
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
plt.figure(dpi=plot_kwargs.dpi)
plt.show()

Expand Down Expand Up @@ -488,16 +503,32 @@ def plot_raster(
ax.plot([0, 0], [0, np.nanmax(raster_y) + 1], color="red", linestyle=":")
ax.plot([events, events], [0, np.nanmax(raster_y) + 1], color="red", linestyle=":")

ax.set(xlabel=plot_kwargs.x_axis, ylabel=ylabel)
ax.set(
xlabel=plot_kwargs.x_axis,
ylabel=ylabel,
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
self.set_plot_kwargs(ax, plot_kwargs)
plt.grid(False)
plt.tight_layout()

self._despine(ax)
if plot_kwargs.title is None:
plt.title(f"{stimulus}: {self.data.cluster_ids[idy]}", fontsize=8)
plt.title(
f"{stimulus}: {self.data.cluster_ids[idy]}",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
else:
plt.title(plot_kwargs.title)
plt.title(
plot_kwargs.title,
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
plt.figure(dpi=plot_kwargs.dpi)
plt.show()

Expand Down Expand Up @@ -647,16 +678,36 @@ def plot_sm_fr(

ax.set(ylim=(0, np.max(mean_smoothed_psth) + np.max(stderr) + 1))
self.set_plot_kwargs(ax, plot_kwargs)
ax.set_ylabel(ylabel)
ax.set_xlabel(plot_kwargs.x_axis)
ax.set_ylabel(
ylabel,
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
ax.set_xlabel(
plot_kwargs.x_axis,
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
plt.tight_layout()

self._despine(ax)

if plot_kwargs.title is not None:
plt.title(plot_kwargs.title)
plt.title(
plot_kwargs.title,
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
else:
plt.title(f"{stimulus}: {self.data.cluster_ids[cluster_number]}", fontsize=8)
plt.title(
f"{stimulus}: {self.data.cluster_ids[cluster_number]}",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
plt.figure(dpi=plot_kwargs.dpi)
plt.show()

Expand Down Expand Up @@ -810,10 +861,25 @@ def plot_latencies(self, colors="red", plot_kwargs={}):
fig, ax = plt.subplots(figsize=plot_kwargs.figsize)
ax.hist(lat_by_neuron, density=True, bins=bins, color=color, alpha=0.8)
ax.hist(shufl_bsl_neuron, density=True, bins=bins, color="k", alpha=0.8)
ax.set_xlabel("Time (ms)", fontsize="small")
ax.set_ylabel("Counts", fontsize="small")
ax.set_xlabel(
"Time (ms)",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
ax.set_ylabel(
"Counts",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
self.set_plot_kwargs(ax, plot_kwargs)
plt.title(f"{stimulus.title()}: {self.data.cluster_ids[neuron]}")
plt.title(
f"{stimulus.title()}: {self.data.cluster_ids[neuron]}",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
self._despine(ax)
plt.tight_layout()
plt.figure(dpi=plot_kwargs.dpi)
Expand All @@ -837,7 +903,9 @@ def plot_isi(self):
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Counts (Normalized)")
self._despine(ax)
plt.title(f"ISI {cluster}")
plt.title(
f"ISI {cluster}",
)
plt.tight_layout()
plt.figure(dpi=self.dpi)
plt.show()
Expand Down Expand Up @@ -891,10 +959,25 @@ def plot_event_isi(self, colors: str | dict, include_ids: list | np.array | None
ax.stairs(sub_bsl, edges=bins, fill=True, color="k", alpha=0.7)
ax.stairs(sub_stim_isi, edges=bins, fill=True, color=color, alpha=0.7)
self.set_plot_kwargs(ax, plot_kwargs)
ax.set_xlabel("Time (ms)")
ax.set_ylabel("Counts (Normalized)")
ax.set_xlabel(
"Time (ms)",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
ax.set_ylabel(
"Counts (Normalized)",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
self._despine(ax)
plt.title(f"isi vs bsl {stimulus}: {self.data.cluster_ids[row]}")
plt.title(
f"isi vs bsl {stimulus}: {self.data.cluster_ids[row]}",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
plt.tight_layout()
plt.figure(dpi=plot_kwargs.dpi)
plt.show()
Expand Down Expand Up @@ -1157,10 +1240,22 @@ def _plot_one_trace(
)

self.set_plot_kwargs(ax, plot_kwargs)
ax.set_xlabel("Time (s)")
ax.set_ylabel(plot_kwargs.y_axis)
ax.set_xlabel(
"Time (s)", fontsize=plot_kwargs.fontsize, fontstyle=plot_kwargs.fontstyle, fontname=plot_kwargs.fontname
)
ax.set_ylabel(
plot_kwargs.y_axis,
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
self._despine(ax)
plt.title(f"trace {stim}")
plt.title(
f"trace {stim}",
fontsize=plot_kwargs.fontsize,
fontstyle=plot_kwargs.fontstyle,
fontname=plot_kwargs.fontname,
)
plt.tight_layout()
plt.figure(dpi=plot_kwargs.dpi)
plt.show()
Expand Down

0 comments on commit 89b4295

Please sign in to comment.