diff --git a/EXPtools/visuals/visualize.py b/EXPtools/visuals/visualize.py index 110ea81..e32dd25 100644 --- a/EXPtools/visuals/visualize.py +++ b/EXPtools/visuals/visualize.py @@ -2,7 +2,9 @@ import numpy as np import matplotlib.pyplot as plt -def make_basis_plot(basis, savefile=None, nsnap='mean', y=0.92, dpi=200): +def make_basis_plot(basis, lmax=6, nmax=20, + savefile=None, nsnap='mean', y=0.92, dpi=200, + lrmin=0.5, lrmax=2.7, rnum=100): """ Plots the potential of the basis functions for different values of l and n. @@ -13,45 +15,43 @@ def make_basis_plot(basis, savefile=None, nsnap='mean', y=0.92, dpi=200): savefile (str, optional): name of the file to save the plot as nsnap (str, optional): description of the snapshot being plotted y: float (optional - vertical position of the main title + vertical position of the main title dpi: int (optional) resolution of the plot in dots per inch Returns ------- - None - None + tuple: A tuple containing fig and ax. """ - # Set up grid for plotting potential - lrmin, lrmax, rnum = 0.5, 2.7, 100 halo_grid = basis.getBasis(lrmin, lrmax, rnum) r = np.linspace(lrmin, lrmax, rnum) r = np.power(10.0, r) # Create subplots and plot potential for each l and n - fig, ax = plt.subplots(4, 5, figsize=(6,6), dpi=dpi, - sharex='col', sharey='row') + fig, ax = plt.subplots(lmax, 1, figsize=(10, 3*lmax), dpi=dpi, + sharex='col', sharey='row') plt.subplots_adjust(wspace=0, hspace=0) ax = ax.flatten() - for l in range(len(ax)): - ax[l].set_title(f"$\ell = {l}$", y=0.8, fontsize=6) - for n in range(20): - ax[l].semilogx(r, halo_grid[l][n]['potential'], '-', label="n={}".format(n), lw=0.5) + for l in range(lmax): + ax[l].set_title(f"$\ell = {l}$", y=0.8, fontsize=16) + for n in range(nmax): + ax[l].semilogx(r, halo_grid[l][n]['potential'], + '-', label="n={}".format(n), lw=0.5) # Add labels and main title fig.supylabel('Potential', weight='bold', x=-0.02) fig.supxlabel('Radius', weight='bold', y=0.02) - fig.suptitle(f'nsnap = {nsnap}', - fontsize=12, - weight='bold', + fig.suptitle(f'nsnap = {nsnap}', + fontsize=12, + weight='bold', y=y, - ) - + ) # Save plot if a filename was provided if savefile: plt.savefig(f'{savefile}', bbox_inches='tight') + return (fig, ax) def find_field(basis, coefficients, time=0, xyz=(0, 0, 0), property='dens', include_monopole=True): """