diff --git a/src/Elasticipy/SphericalFunction.py b/src/Elasticipy/SphericalFunction.py index 654814f..6ce8ee9 100644 --- a/src/Elasticipy/SphericalFunction.py +++ b/src/Elasticipy/SphericalFunction.py @@ -745,7 +745,7 @@ def plot_xyz_sections(self, n_theta=500, n_psi=100, color_minmax='blue', alpha_m return fig, axs def plot_as_pole_figure(self, n_theta=50, n_phi=200, n_psi=50, which='mean', projection='lambert', fig=None, - plot_type='imshow', **kwargs): + plot_type='imshow', show=True, title=None, subplot_args=(), subplot_kwargs=None, **kwargs): """ Generate a pole figure plot from spherical function evaluation. @@ -773,6 +773,12 @@ def plot_as_pole_figure(self, n_theta=50, n_phi=200, n_psi=50, which='mean', pro plot_type : str, optional Type of plot to generate. Can be 'imshow', 'contourf', or 'contour'. Default is 'imshow'. + show : bool, optional + Set whether to show the plot or not. Default is True. This must be turned off when using multiple subplots. + subplot_args : tuple, optional + Arguments to pass to the add_subplot() function. Default is None. + subplot_kwargs : dict, optional + Keyword arguments to pass to the add_subplot() function. Default is None. **kwargs : dict, optional Additional keyword arguments passed to the plotting functions. @@ -783,9 +789,11 @@ def plot_as_pole_figure(self, n_theta=50, n_phi=200, n_psi=50, which='mean', pro ax : matplotlib.axes.Axes The axes object containing the plot. """ + if subplot_kwargs is None: + subplot_kwargs = {} if fig is None: fig = plt.figure() - ax = add_polefigure(fig, projection=projection) + ax = add_polefigure(fig, *subplot_kwargs, projection=projection, **subplot_kwargs) phi = np.linspace(*self.domain[0], n_phi) theta = np.linspace(*self.domain[1], n_theta) psi = np.linspace(*self.domain[2], n_psi) @@ -813,6 +821,8 @@ def plot_as_pole_figure(self, n_theta=50, n_phi=200, n_psi=50, which='mean', pro else: raise ValueError(f'Unknown plot type: {plot_type}') ax.set_rlim(*self.domain[1]) + ax.set_title(title) fig.colorbar(sc) - plt.show() + if show: + plt.show() return fig, ax \ No newline at end of file