Skip to content

Commit

Permalink
Merge pull request #216 from esheldon/test-for-speed
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr authored Jun 21, 2022
2 parents 38c3790 + 5007cdb commit 0b90f27
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 28 deletions.
2 changes: 1 addition & 1 deletion mdet_tests/test_mdet_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_mdet_regression(fname, write=False):
),
}
else:
assert col in ["shear"]
assert col in ["shear", "shear_bands"]


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions ngmix/jacobian/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ def _finish_init(self, row0, col0, dvdrow, dvdcol, dudrow, dudcol):

def __repr__(self):
fmt = (
'row0: %-10.5g col0: %-10.5g dvdrow: %-10.5g '
'dvdcol: %-10.5g dudrow: %-10.5g dudcol: %-10.5g'
'ngmix.Jacobian(row=%r, col=%r, dvdrow=%r, '
'dvdcol=%r, dudrow=%r, dudcol=%r)'
)
return fmt % (self.row0,
self.col0,
Expand Down
33 changes: 22 additions & 11 deletions ngmix/metacal/metacal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import copy
import logging
from functools import lru_cache
import numpy as np
from ..gexceptions import GMixRangeError, BootPSFFailure
from ..shape import Shape
Expand All @@ -20,6 +21,14 @@
logger = logging.getLogger(__name__)


@lru_cache(maxsize=128)
def _cached_galsim_stuff(img, wcs_repr, xinterp):
import galsim
image = galsim.Image(np.array(img), wcs=eval(wcs_repr))
image_int = galsim.InterpolatedImage(image, x_interpolant=xinterp)
return image, image_int


class MetacalDilatePSF(object):
"""
Create manipulated images for use in metacalibration
Expand Down Expand Up @@ -355,22 +364,24 @@ def _set_data(self):
# these would share data with the original numpy arrays, make copies
# to be sure they don't get modified
#
self.image = galsim.Image(obs.image.copy(),
wcs=self.get_wcs())

self.psf_image = galsim.Image(obs.psf.image.copy(),
wcs=self.get_psf_wcs())
image, image_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.image.copy()),
repr(self.get_wcs()),
self.interp,
)
self.image = image
self.image_int = image_int

# interpolated psf image
psf_int = galsim.InterpolatedImage(self.psf_image,
x_interpolant=self.interp)
psf_image, psf_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.psf.image.copy()),
repr(self.get_psf_wcs()),
self.interp,
)
self.psf_image = psf_image

# this can be used to deconvolve the psf from the galaxy image
psf_int_inv = galsim.Deconvolve(psf_int)

self.image_int = galsim.InterpolatedImage(self.image,
x_interpolant=self.interp)

# deconvolved galaxy image, psf+pixel removed
self.image_int_nopsf = galsim.Convolve(self.image_int,
psf_int_inv)
Expand Down
49 changes: 38 additions & 11 deletions ngmix/prepsfmom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import functools

import numpy as np
import scipy.fft as fft
Expand Down Expand Up @@ -84,14 +85,14 @@ def _meas(self, obs, psf_obs, return_kernels):
eff_pad_factor = target_dim / obs.image.shape[0]

# pad image, psf and weight map, get FFTs, apply cen_phases
kim, im_row, im_col = _zero_pad_and_compute_fft(
kim, im_row, im_col = _zero_pad_and_compute_fft_cached(
obs.image, obs.jacobian.row0, obs.jacobian.col0, target_dim,
self.ap_rad,
)
fft_dim = kim.shape[0]

if psf_obs is not None:
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft(
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft_cached(
psf_obs.image,
psf_obs.jacobian.row0, psf_obs.jacobian.col0,
target_dim,
Expand Down Expand Up @@ -128,17 +129,17 @@ def _meas(self, obs, psf_obs, return_kernels):
# now build the kernels
if self.kernel == "ksigma":
kernels = _ksigma_kernels(
target_dim,
self.fwhm,
obs.jacobian.dvdrow, obs.jacobian.dvdcol,
obs.jacobian.dudrow, obs.jacobian.dudcol,
int(target_dim),
float(self.fwhm),
float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol),
float(obs.jacobian.dudrow), float(obs.jacobian.dudcol),
)
elif self.kernel in ["gauss", "pgauss"]:
kernels = _gauss_kernels(
target_dim,
self.fwhm,
obs.jacobian.dvdrow, obs.jacobian.dvdcol,
obs.jacobian.dudrow, obs.jacobian.dudcol,
int(target_dim),
float(self.fwhm),
float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol),
float(obs.jacobian.dudrow), float(obs.jacobian.dudcol),
)
else:
raise ValueError(
Expand Down Expand Up @@ -365,7 +366,7 @@ def _compute_cen_phase_shift(cen_row, cen_col, dim, msk=None):
return np.cos(kcen) + 1j*np.sin(kcen)


def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad):
def _zero_pad_and_compute_fft_impl(im, cen_row, cen_col, target_dim, ap_rad):
"""zero pad and compute the FFT
Returns the fft, cen_row in the padded image, and cen_col in the padded image.
Expand All @@ -382,6 +383,30 @@ def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad):
return kpim, pad_cen_row, pad_cen_col


# see https://stackoverflow.com/a/52332109 for how this works
@functools.lru_cache(maxsize=128)
def _zero_pad_and_compute_fft_cached_impl(
im_tuple, cen_row, cen_col, target_dim, ap_rad
):
return _zero_pad_and_compute_fft_impl(
np.array(im_tuple), cen_row, cen_col, target_dim, ap_rad
)


@functools.wraps(_zero_pad_and_compute_fft_impl)
def _zero_pad_and_compute_fft_cached(im, cen_row, cen_col, target_dim, ap_rad):
return _zero_pad_and_compute_fft_cached_impl(
tuple(tuple(ii) for ii in im),
float(cen_row), float(cen_col), int(target_dim), float(ap_rad)
)


_zero_pad_and_compute_fft_cached.cache_info \
= _zero_pad_and_compute_fft_cached_impl.cache_info
_zero_pad_and_compute_fft_cached.cache_clear \
= _zero_pad_and_compute_fft_cached_impl.cache_clear


def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5):
"""deconvolve the PSF from an image in place.
Expand All @@ -398,6 +423,7 @@ def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5):
return kim, kpsf_im, msk


@functools.lru_cache(maxsize=128)
def _ksigma_kernels(
dim,
kernel_size,
Expand Down Expand Up @@ -505,6 +531,7 @@ def _ksigma_kernels(
)


@functools.lru_cache(maxsize=128)
def _gauss_kernels(
dim,
kernel_size,
Expand Down
8 changes: 5 additions & 3 deletions ngmix/tests/_galsim_sims.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import galsim


def _get_obs(rng, set_noise_image=False, noise=1.0e-6):
def _get_obs(rng, set_noise_image=False, noise=1.0e-6, psf_fwhm=0.9, n=None):

psf_noise = 1.0e-6

scale = 0.263

psf_fwhm = 0.9
gal_fwhm = 0.7

psf = galsim.Gaussian(fwhm=psf_fwhm)
Expand All @@ -18,7 +17,10 @@ def _get_obs(rng, set_noise_image=False, noise=1.0e-6):
obj = galsim.Convolve(psf, obj0)

psf_im = psf.drawImage(scale=scale).array
im = obj.drawImage(scale=scale).array
if n is not None:
im = obj.drawImage(scale=scale, nx=n, ny=n).array
else:
im = obj.drawImage(scale=scale).array

psf_im += rng.normal(scale=psf_noise, size=psf_im.shape)
im += rng.normal(scale=noise, size=im.shape)
Expand Down
41 changes: 41 additions & 0 deletions ngmix/tests/test_metacal_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import time
import numpy as np
import ngmix
import ngmix.metacal.metacal
from ._galsim_sims import _get_obs
from ..metacal.metacal import _cached_galsim_stuff


def test_metacal_cache():
# first warm up numba
rng = np.random.RandomState(seed=100)
obs = _get_obs(rng, noise=0.005, set_noise_image=True, psf_fwhm=0.8, n=300)
t0 = time.time()
ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"])
t0 = time.time() - t0
print("first time: %r seconds" % t0, flush=True)
print(_cached_galsim_stuff.cache_info(), flush=True)

# now cache it
rng = np.random.RandomState(seed=10)
obs = _get_obs(rng, noise=0.005, set_noise_image=True, n=300)
t1 = time.time()
ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"])
t1 = time.time() - t1
print("second time: %r seconds" % t1, flush=True)
print(_cached_galsim_stuff.cache_info(), flush=True)

# now use cache
rng = np.random.RandomState(seed=10)
obs = _get_obs(rng, noise=0.005, set_noise_image=True, n=300)
t2 = time.time()
ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"])
t2 = time.time() - t2
print("third time: %r seconds (< %r?)" % (t2, t1*0.7), flush=True)
print(_cached_galsim_stuff.cache_info(), flush=True)

# numba should be slower always but we do not care how much
assert t1 < t0

# we expect roughly 30% gains
assert t2 < t1*0.7
Loading

0 comments on commit 0b90f27

Please sign in to comment.