diff --git a/pyproject.toml b/pyproject.toml index f9d23e6..25d99bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,4 +56,5 @@ omit = [ "src/spikeanalysis/analysis_utils/*", # utils are all numba "src/spikeanalysis/spike_plotter.py", # no testing for actual plotting yet "src/spikeanalysis/intrinsic_plotter.py", # no testing for plotting yet + "src/spikeanalysis/plotting.functions.py", # no test for plotting yet ] diff --git a/src/spikeanalysis/__init__.py b/src/spikeanalysis/__init__.py index edbfcf3..03d778e 100644 --- a/src/spikeanalysis/__init__.py +++ b/src/spikeanalysis/__init__.py @@ -6,6 +6,7 @@ from .analog_analysis import AnalogAnalysis from .curated_spike_analysis import CuratedSpikeAnalysis, read_responsive_neurons from .stats_functions import kolmo_smir_stats +from .plotting_functions import plot_piechart import importlib.metadata diff --git a/src/spikeanalysis/plotting_functions.py b/src/spikeanalysis/plotting_functions.py new file mode 100644 index 0000000..b1906e5 --- /dev/null +++ b/src/spikeanalysis/plotting_functions.py @@ -0,0 +1,46 @@ +from typing import Optional, Sequence +import matplotlib.pyplot as plt + + +def plot_piechart(self, wedges: Sequence, counts: Sequence, colors: Optional[Sequence] = None): + """Plots a piechart""" + + assert len(wedges) == len(counts), "each wedge needs a corresponding count" + assert not counts.index(0), "counts with 0 will display incorrectly" + assert counts[0] != 0, "counts with 0 will display incorrectly" + import numpy as np + + if self.figsize[0] <= 10: + fontsize = 10 + else: + fontsize = 14 + + f, ax = plt.subplots(figsize=self.figsize) + + if colors is None: + colors = [ + "#ff9999", + "#66b3ff", + "#99ff99", + "#FEC8D8", + "#ffcc99", + "#F6BF85", + "#B7ADED", + ] + + ax.pie( + counts, + labels=wedges, + autopct=lambda pct: "{:.1f}%\n(n={:d})".format(pct, int(np.round(pct / 100 * np.sum(counts)))), + shadow=False, + startangle=90, + colors=colors, + textprops={"fontsize": fontsize}, + ) + ax.axis("equal") + if self.title: + plt.title(self.title) + plt.tight_layout() + plt.figure(dpi=self.dpi) + + plt.show()