From 8e2bbfa2a135093dfa17f2331972b0a3838022e3 Mon Sep 17 00:00:00 2001 From: Erin Sheldon Date: Wed, 23 Oct 2024 21:59:25 -0400 Subject: [PATCH] convert GMixND.plot to use matplotlib --- CHANGES.md | 4 ++++ ngmix/gmix_ndim/gmix_ndim.py | 33 ++++++++++++++++++++++----------- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index f940cd8d..9b17c333 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -6,6 +6,10 @@ Only affected rng such as np.random.default_rng that checks range for uniform +## compatibility + + - Convert GMixND.plot to use matplotlib + ## v2.3.1 ### new features diff --git a/ngmix/gmix_ndim/gmix_ndim.py b/ngmix/gmix_ndim/gmix_ndim.py index c353cb0d..d63fb07c 100644 --- a/ngmix/gmix_ndim/gmix_ndim.py +++ b/ngmix/gmix_ndim/gmix_ndim.py @@ -160,10 +160,12 @@ def plot( ------- plot object """ + import numpy as np import esutil as eu - import hickory + import matplotlib.pyplot as mplt + from itertools import cycle - plt = hickory.Plot(**plot_kws) + fig, ax = mplt.subplots() if data is not None: @@ -183,15 +185,19 @@ def plot( dx_model = dx_data/10 npts = int((max - min)/dx_model) - xvals = numpy.linspace( + xvals = np.linspace( min, max, npts, ) dx_model = xvals[1] - xvals[0] - plt.bar(hd['center'], hd['hist'], label='data', width=dx_data, - alpha=0.5, color='#a6a6a6') + ax.bar( + hd['center'], hd['hist'], label='data', + width=dx_data, + alpha=0.5, + color='#a6a6a6', + ) else: if npts is None: @@ -201,7 +207,7 @@ def plot( if max is None: raise ValueError('send max if not sending data') - xvals = numpy.linspace(min, max, npts) + xvals = np.linspace(min, max, npts) predicted = self.get_prob_array(xvals) @@ -212,21 +218,26 @@ def plot( else: fac = 1 - plt.curve(xvals, predicted, label='model') + lines = ["-", "--", "-.", ":"] + linecycler = cycle(lines) + + ax.plot(xvals, predicted, ls=next(linecycler), label='model') + for i in range(self.ngauss): predicted = fac*self.get_prob_array(xvals, component=i) label = 'component %d' % i - plt.curve(xvals, predicted, label=label) + ax.plot(xvals, predicted, label=label, ls=next(linecycler)) + ax.legend() if show: - plt.show() + mplt.show() if file is not None: print('writing:', file) - plt.savefig(file, dpi=dpi) + fig.savefig(file, dpi=dpi) - return plt + return fig, ax def save_mixture(self, fname): """