diff --git a/mdet_tests/test_mdet_regression.py b/mdet_tests/test_mdet_regression.py index c4047a80..e48e565b 100644 --- a/mdet_tests/test_mdet_regression.py +++ b/mdet_tests/test_mdet_regression.py @@ -265,7 +265,7 @@ def test_mdet_regression(fname, write=False): ), } else: - assert col in ["shear"] + assert col in ["shear", "shear_bands"] if __name__ == "__main__": diff --git a/ngmix/jacobian/jacobian.py b/ngmix/jacobian/jacobian.py index f7f555f1..044525a3 100644 --- a/ngmix/jacobian/jacobian.py +++ b/ngmix/jacobian/jacobian.py @@ -318,8 +318,8 @@ def _finish_init(self, row0, col0, dvdrow, dvdcol, dudrow, dudcol): def __repr__(self): fmt = ( - 'row0: %-10.5g col0: %-10.5g dvdrow: %-10.5g ' - 'dvdcol: %-10.5g dudrow: %-10.5g dudcol: %-10.5g' + 'ngmix.Jacobian(row=%r, col=%r, dvdrow=%r, ' + 'dvdcol=%r, dudrow=%r, dudcol=%r)' ) return fmt % (self.row0, self.col0, diff --git a/ngmix/metacal/metacal.py b/ngmix/metacal/metacal.py index a79f3d74..9b938671 100644 --- a/ngmix/metacal/metacal.py +++ b/ngmix/metacal/metacal.py @@ -6,6 +6,7 @@ """ import copy import logging +from functools import lru_cache import numpy as np from ..gexceptions import GMixRangeError, BootPSFFailure from ..shape import Shape @@ -20,6 +21,14 @@ logger = logging.getLogger(__name__) +@lru_cache(maxsize=128) +def _cached_galsim_stuff(img, wcs_repr, xinterp): + import galsim + image = galsim.Image(np.array(img), wcs=eval(wcs_repr)) + image_int = galsim.InterpolatedImage(image, x_interpolant=xinterp) + return image, image_int + + class MetacalDilatePSF(object): """ Create manipulated images for use in metacalibration @@ -355,22 +364,24 @@ def _set_data(self): # these would share data with the original numpy arrays, make copies # to be sure they don't get modified # - self.image = galsim.Image(obs.image.copy(), - wcs=self.get_wcs()) - - self.psf_image = galsim.Image(obs.psf.image.copy(), - wcs=self.get_psf_wcs()) + image, image_int = _cached_galsim_stuff( + tuple(tuple(ii) for ii in obs.image.copy()), + repr(self.get_wcs()), + self.interp, + ) + self.image = image + self.image_int = image_int - # interpolated psf image - psf_int = galsim.InterpolatedImage(self.psf_image, - x_interpolant=self.interp) + psf_image, psf_int = _cached_galsim_stuff( + tuple(tuple(ii) for ii in obs.psf.image.copy()), + repr(self.get_psf_wcs()), + self.interp, + ) + self.psf_image = psf_image # this can be used to deconvolve the psf from the galaxy image psf_int_inv = galsim.Deconvolve(psf_int) - self.image_int = galsim.InterpolatedImage(self.image, - x_interpolant=self.interp) - # deconvolved galaxy image, psf+pixel removed self.image_int_nopsf = galsim.Convolve(self.image_int, psf_int_inv) diff --git a/ngmix/prepsfmom.py b/ngmix/prepsfmom.py index 2994e9bd..778c0635 100644 --- a/ngmix/prepsfmom.py +++ b/ngmix/prepsfmom.py @@ -1,4 +1,5 @@ import logging +import functools import numpy as np import scipy.fft as fft @@ -84,14 +85,14 @@ def _meas(self, obs, psf_obs, return_kernels): eff_pad_factor = target_dim / obs.image.shape[0] # pad image, psf and weight map, get FFTs, apply cen_phases - kim, im_row, im_col = _zero_pad_and_compute_fft( + kim, im_row, im_col = _zero_pad_and_compute_fft_cached( obs.image, obs.jacobian.row0, obs.jacobian.col0, target_dim, self.ap_rad, ) fft_dim = kim.shape[0] if psf_obs is not None: - kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft( + kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft_cached( psf_obs.image, psf_obs.jacobian.row0, psf_obs.jacobian.col0, target_dim, @@ -128,17 +129,17 @@ def _meas(self, obs, psf_obs, return_kernels): # now build the kernels if self.kernel == "ksigma": kernels = _ksigma_kernels( - target_dim, - self.fwhm, - obs.jacobian.dvdrow, obs.jacobian.dvdcol, - obs.jacobian.dudrow, obs.jacobian.dudcol, + int(target_dim), + float(self.fwhm), + float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol), + float(obs.jacobian.dudrow), float(obs.jacobian.dudcol), ) elif self.kernel in ["gauss", "pgauss"]: kernels = _gauss_kernels( - target_dim, - self.fwhm, - obs.jacobian.dvdrow, obs.jacobian.dvdcol, - obs.jacobian.dudrow, obs.jacobian.dudcol, + int(target_dim), + float(self.fwhm), + float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol), + float(obs.jacobian.dudrow), float(obs.jacobian.dudcol), ) else: raise ValueError( @@ -365,7 +366,7 @@ def _compute_cen_phase_shift(cen_row, cen_col, dim, msk=None): return np.cos(kcen) + 1j*np.sin(kcen) -def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad): +def _zero_pad_and_compute_fft_impl(im, cen_row, cen_col, target_dim, ap_rad): """zero pad and compute the FFT Returns the fft, cen_row in the padded image, and cen_col in the padded image. @@ -382,6 +383,30 @@ def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad): return kpim, pad_cen_row, pad_cen_col +# see https://stackoverflow.com/a/52332109 for how this works +@functools.lru_cache(maxsize=128) +def _zero_pad_and_compute_fft_cached_impl( + im_tuple, cen_row, cen_col, target_dim, ap_rad +): + return _zero_pad_and_compute_fft_impl( + np.array(im_tuple), cen_row, cen_col, target_dim, ap_rad + ) + + +@functools.wraps(_zero_pad_and_compute_fft_impl) +def _zero_pad_and_compute_fft_cached(im, cen_row, cen_col, target_dim, ap_rad): + return _zero_pad_and_compute_fft_cached_impl( + tuple(tuple(ii) for ii in im), + float(cen_row), float(cen_col), int(target_dim), float(ap_rad) + ) + + +_zero_pad_and_compute_fft_cached.cache_info \ + = _zero_pad_and_compute_fft_cached_impl.cache_info +_zero_pad_and_compute_fft_cached.cache_clear \ + = _zero_pad_and_compute_fft_cached_impl.cache_clear + + def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5): """deconvolve the PSF from an image in place. @@ -398,6 +423,7 @@ def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5): return kim, kpsf_im, msk +@functools.lru_cache(maxsize=128) def _ksigma_kernels( dim, kernel_size, @@ -505,6 +531,7 @@ def _ksigma_kernels( ) +@functools.lru_cache(maxsize=128) def _gauss_kernels( dim, kernel_size, diff --git a/ngmix/tests/_galsim_sims.py b/ngmix/tests/_galsim_sims.py index 13e22ffa..2d40d054 100644 --- a/ngmix/tests/_galsim_sims.py +++ b/ngmix/tests/_galsim_sims.py @@ -3,13 +3,12 @@ import galsim -def _get_obs(rng, set_noise_image=False, noise=1.0e-6): +def _get_obs(rng, set_noise_image=False, noise=1.0e-6, psf_fwhm=0.9, n=None): psf_noise = 1.0e-6 scale = 0.263 - psf_fwhm = 0.9 gal_fwhm = 0.7 psf = galsim.Gaussian(fwhm=psf_fwhm) @@ -18,7 +17,10 @@ def _get_obs(rng, set_noise_image=False, noise=1.0e-6): obj = galsim.Convolve(psf, obj0) psf_im = psf.drawImage(scale=scale).array - im = obj.drawImage(scale=scale).array + if n is not None: + im = obj.drawImage(scale=scale, nx=n, ny=n).array + else: + im = obj.drawImage(scale=scale).array psf_im += rng.normal(scale=psf_noise, size=psf_im.shape) im += rng.normal(scale=noise, size=im.shape) diff --git a/ngmix/tests/test_metacal_cache.py b/ngmix/tests/test_metacal_cache.py new file mode 100644 index 00000000..43ab99df --- /dev/null +++ b/ngmix/tests/test_metacal_cache.py @@ -0,0 +1,41 @@ +import time +import numpy as np +import ngmix +import ngmix.metacal.metacal +from ._galsim_sims import _get_obs +from ..metacal.metacal import _cached_galsim_stuff + + +def test_metacal_cache(): + # first warm up numba + rng = np.random.RandomState(seed=100) + obs = _get_obs(rng, noise=0.005, set_noise_image=True, psf_fwhm=0.8, n=300) + t0 = time.time() + ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"]) + t0 = time.time() - t0 + print("first time: %r seconds" % t0, flush=True) + print(_cached_galsim_stuff.cache_info(), flush=True) + + # now cache it + rng = np.random.RandomState(seed=10) + obs = _get_obs(rng, noise=0.005, set_noise_image=True, n=300) + t1 = time.time() + ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"]) + t1 = time.time() - t1 + print("second time: %r seconds" % t1, flush=True) + print(_cached_galsim_stuff.cache_info(), flush=True) + + # now use cache + rng = np.random.RandomState(seed=10) + obs = _get_obs(rng, noise=0.005, set_noise_image=True, n=300) + t2 = time.time() + ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"]) + t2 = time.time() - t2 + print("third time: %r seconds (< %r?)" % (t2, t1*0.7), flush=True) + print(_cached_galsim_stuff.cache_info(), flush=True) + + # numba should be slower always but we do not care how much + assert t1 < t0 + + # we expect roughly 30% gains + assert t2 < t1*0.7 diff --git a/ngmix/tests/test_prepsfmom.py b/ngmix/tests/test_prepsfmom.py index d24a1ff6..f82e6aec 100644 --- a/ngmix/tests/test_prepsfmom.py +++ b/ngmix/tests/test_prepsfmom.py @@ -1,11 +1,15 @@ import galsim import numpy as np import pytest +import time +from flaky import flaky from ngmix.prepsfmom import ( KSigmaMom, PGaussMom, _build_square_apodization_mask, PrePSFMom, + _gauss_kernels, + _zero_pad_and_compute_fft_cached_impl, ) from ngmix import Jacobian from ngmix import Observation @@ -94,6 +98,135 @@ def test_prepsfmom_raises_badjacob(cls): assert "same WCS Jacobia" in str(e.value) +@flaky +def test_prepsfmom_speed_and_cache(): + image_size = 48 + psf_image_size = 53 + pixel_scale = 0.263 + fwhm = 0.9 + psf_fwhm = 0.9 + snr = 20 + mom_fwhm = 2 + + rng = np.random.RandomState(seed=100) + + cen = (image_size - 1)/2 + psf_cen = (psf_image_size - 1)/2 + gs_wcs = galsim.ShearWCS( + pixel_scale, galsim.Shear(g1=-0.1, g2=0.06)).jacobian() + scale = np.sqrt(gs_wcs.pixelArea()) + shift = rng.uniform(low=-scale/2, high=scale/2, size=2) + psf_shift = rng.uniform(low=-scale/2, high=scale/2, size=2) + xy = gs_wcs.toImage(galsim.PositionD(shift)) + psf_xy = gs_wcs.toImage(galsim.PositionD(psf_shift)) + + jac = Jacobian( + y=cen + xy.y, x=cen + xy.x, + dudx=gs_wcs.dudx, dudy=gs_wcs.dudy, + dvdx=gs_wcs.dvdx, dvdy=gs_wcs.dvdy) + + psf_jac = Jacobian( + y=psf_cen + psf_xy.y, x=psf_cen + psf_xy.x, + dudx=gs_wcs.dudx, dudy=gs_wcs.dudy, + dvdx=gs_wcs.dvdx, dvdy=gs_wcs.dvdy) + + gal = galsim.Gaussian( + fwhm=fwhm + ).shear( + g1=-0.1, g2=0.2 + ).withFlux( + 400 + ).shift( + dx=shift[0], dy=shift[1] + ) + psf = galsim.Gaussian( + fwhm=psf_fwhm + ).shear( + g1=0.3, g2=-0.15 + ) + im = galsim.Convolve([gal, psf]).drawImage( + nx=image_size, + ny=image_size, + wcs=gs_wcs + ).array + noise = np.sqrt(np.sum(im**2)) / snr + wgt = np.ones_like(im) / noise**2 + + psf_im = psf.shift( + dx=psf_shift[0], dy=psf_shift[1] + ).drawImage( + nx=psf_image_size, + ny=psf_image_size, + wcs=gs_wcs + ).array + + # now we test the speed + caching + _gauss_kernels.cache_clear() + _zero_pad_and_compute_fft_cached_impl.cache_clear() + + # the first fit will do numba stuff, so we exclude it + # we also perturb the various inputs to fool our caches + fitter = PGaussMom( + fwhm=mom_fwhm + 1e-3, + ) + + obs = Observation( + image=im + 1e-6, + weight=wgt, + jacobian=jac, + psf=Observation(image=psf_im + 1e-8, jacobian=psf_jac), + ) + + dt = time.time() + fitter.go(obs=obs) + dt1 = time.time() - dt + print("\n%0.4f ms for first fit" % (dt1*1000)) + + # we miss once here for kernels, twice for images + assert _gauss_kernels.cache_info().misses == 1 + assert _zero_pad_and_compute_fft_cached_impl.cache_info().misses == 2 + + # the second fit will have numba cached, but not the other kernel and FFT caches + fitter = PGaussMom( + fwhm=mom_fwhm, + ) + + obs = Observation( + image=im, + weight=wgt, + jacobian=jac, + psf=Observation(image=psf_im, jacobian=psf_jac), + ) + + dt = time.time() + fitter.go(obs=obs) + dt2 = time.time() - dt + print("%0.4f ms for second fit" % (dt2*1000)) + + # we miss twice for kernels, total of 3 times since psf changed + assert _gauss_kernels.cache_info().misses == 2 + assert _zero_pad_and_compute_fft_cached_impl.cache_info().misses == 4 + + # now we test with full caching + nfit = 1000 + dt = time.time() + for _ in range(nfit): + with obs.writeable(): + obs.image += 1e-6 + fitter.go(obs=obs) + dt3 = time.time() - dt + + print("%0.4f ms per fit" % (dt3/nfit*1000)) + + # we should never miss again for the calls above + assert _gauss_kernels.cache_info().misses == 2 + assert _zero_pad_and_compute_fft_cached_impl.cache_info().misses == 4 + nfit + + # if numba stuff is cached this does not work so commented out + # assert dt2 < dt1 + assert dt3/nfit < dt2*0.6 + + def _stack_list_of_dicts(res): def _get_dtype(v): if isinstance(v, float):