Skip to content

Commit

Permalink
Merge pull request #243 from esheldon/gmix-nd-plot
Browse files Browse the repository at this point in the history
convert GMixND.plot to use matplotlib
  • Loading branch information
esheldon authored Oct 24, 2024
2 parents 180a27e + 8e2bbfa commit 74eddec
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
Only affected rng such as np.random.default_rng that checks
range for uniform

## compatibility

- Convert GMixND.plot to use matplotlib

## v2.3.1

### new features
Expand Down
33 changes: 22 additions & 11 deletions ngmix/gmix_ndim/gmix_ndim.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,12 @@ def plot(
-------
plot object
"""
import numpy as np
import esutil as eu
import hickory
import matplotlib.pyplot as mplt
from itertools import cycle

plt = hickory.Plot(**plot_kws)
fig, ax = mplt.subplots()

if data is not None:

Expand All @@ -183,15 +185,19 @@ def plot(
dx_model = dx_data/10
npts = int((max - min)/dx_model)

xvals = numpy.linspace(
xvals = np.linspace(
min,
max,
npts,
)
dx_model = xvals[1] - xvals[0]

plt.bar(hd['center'], hd['hist'], label='data', width=dx_data,
alpha=0.5, color='#a6a6a6')
ax.bar(
hd['center'], hd['hist'], label='data',
width=dx_data,
alpha=0.5,
color='#a6a6a6',
)

else:
if npts is None:
Expand All @@ -201,7 +207,7 @@ def plot(
if max is None:
raise ValueError('send max if not sending data')

xvals = numpy.linspace(min, max, npts)
xvals = np.linspace(min, max, npts)

predicted = self.get_prob_array(xvals)

Expand All @@ -212,21 +218,26 @@ def plot(
else:
fac = 1

plt.curve(xvals, predicted, label='model')
lines = ["-", "--", "-.", ":"]
linecycler = cycle(lines)

ax.plot(xvals, predicted, ls=next(linecycler), label='model')

for i in range(self.ngauss):
predicted = fac*self.get_prob_array(xvals, component=i)

label = 'component %d' % i
plt.curve(xvals, predicted, label=label)
ax.plot(xvals, predicted, label=label, ls=next(linecycler))

ax.legend()
if show:
plt.show()
mplt.show()

if file is not None:
print('writing:', file)
plt.savefig(file, dpi=dpi)
fig.savefig(file, dpi=dpi)

return plt
return fig, ax

def save_mixture(self, fname):
"""
Expand Down

0 comments on commit 74eddec

Please sign in to comment.