Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF cache FFTs of various quantities #216

Merged
merged 8 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mdet_tests/test_mdet_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ def test_mdet_regression(fname, write=False):
),
}
else:
assert col in ["shear"]
assert col in ["shear", "shear_bands"]


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions ngmix/jacobian/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ def _finish_init(self, row0, col0, dvdrow, dvdcol, dudrow, dudcol):

def __repr__(self):
fmt = (
'row0: %-10.5g col0: %-10.5g dvdrow: %-10.5g '
'dvdcol: %-10.5g dudrow: %-10.5g dudcol: %-10.5g'
'ngmix.Jacobian(row=%r, col=%r, dvdrow=%r, '
'dvdcol=%r, dudrow=%r, dudcol=%r)'
)
return fmt % (self.row0,
self.col0,
Expand Down
33 changes: 22 additions & 11 deletions ngmix/metacal/metacal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import copy
import logging
from functools import lru_cache
import numpy as np
from ..gexceptions import GMixRangeError, BootPSFFailure
from ..shape import Shape
Expand All @@ -20,6 +21,14 @@
logger = logging.getLogger(__name__)


@lru_cache(maxsize=128)
def _cached_galsim_stuff(img, wcs_repr, xinterp):
import galsim
image = galsim.Image(np.array(img), wcs=eval(wcs_repr))
image_int = galsim.InterpolatedImage(image, x_interpolant=xinterp)
return image, image_int


class MetacalDilatePSF(object):
"""
Create manipulated images for use in metacalibration
Expand Down Expand Up @@ -355,22 +364,24 @@ def _set_data(self):
# these would share data with the original numpy arrays, make copies
# to be sure they don't get modified
#
self.image = galsim.Image(obs.image.copy(),
wcs=self.get_wcs())

self.psf_image = galsim.Image(obs.psf.image.copy(),
wcs=self.get_psf_wcs())
image, image_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.image.copy()),
repr(self.get_wcs()),
self.interp,
)
self.image = image
self.image_int = image_int

# interpolated psf image
psf_int = galsim.InterpolatedImage(self.psf_image,
x_interpolant=self.interp)
psf_image, psf_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.psf.image.copy()),
repr(self.get_psf_wcs()),
self.interp,
)
self.psf_image = psf_image

# this can be used to deconvolve the psf from the galaxy image
psf_int_inv = galsim.Deconvolve(psf_int)

self.image_int = galsim.InterpolatedImage(self.image,
x_interpolant=self.interp)

# deconvolved galaxy image, psf+pixel removed
self.image_int_nopsf = galsim.Convolve(self.image_int,
psf_int_inv)
Expand Down
49 changes: 38 additions & 11 deletions ngmix/prepsfmom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import functools

import numpy as np
import scipy.fft as fft
Expand Down Expand Up @@ -84,14 +85,14 @@ def _meas(self, obs, psf_obs, return_kernels):
eff_pad_factor = target_dim / obs.image.shape[0]

# pad image, psf and weight map, get FFTs, apply cen_phases
kim, im_row, im_col = _zero_pad_and_compute_fft(
kim, im_row, im_col = _zero_pad_and_compute_fft_cached(
obs.image, obs.jacobian.row0, obs.jacobian.col0, target_dim,
self.ap_rad,
)
fft_dim = kim.shape[0]

if psf_obs is not None:
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft(
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft_cached(
psf_obs.image,
psf_obs.jacobian.row0, psf_obs.jacobian.col0,
target_dim,
Expand Down Expand Up @@ -128,17 +129,17 @@ def _meas(self, obs, psf_obs, return_kernels):
# now build the kernels
if self.kernel == "ksigma":
kernels = _ksigma_kernels(
target_dim,
self.fwhm,
obs.jacobian.dvdrow, obs.jacobian.dvdcol,
obs.jacobian.dudrow, obs.jacobian.dudcol,
int(target_dim),
float(self.fwhm),
float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol),
float(obs.jacobian.dudrow), float(obs.jacobian.dudcol),
)
elif self.kernel in ["gauss", "pgauss"]:
kernels = _gauss_kernels(
target_dim,
self.fwhm,
obs.jacobian.dvdrow, obs.jacobian.dvdcol,
obs.jacobian.dudrow, obs.jacobian.dudcol,
int(target_dim),
float(self.fwhm),
float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol),
float(obs.jacobian.dudrow), float(obs.jacobian.dudcol),
)
else:
raise ValueError(
Expand Down Expand Up @@ -365,7 +366,7 @@ def _compute_cen_phase_shift(cen_row, cen_col, dim, msk=None):
return np.cos(kcen) + 1j*np.sin(kcen)


def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad):
def _zero_pad_and_compute_fft_impl(im, cen_row, cen_col, target_dim, ap_rad):
"""zero pad and compute the FFT

Returns the fft, cen_row in the padded image, and cen_col in the padded image.
Expand All @@ -382,6 +383,30 @@ def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad):
return kpim, pad_cen_row, pad_cen_col


# see https://stackoverflow.com/a/52332109 for how this works
@functools.lru_cache(maxsize=128)
def _zero_pad_and_compute_fft_cached_impl(
im_tuple, cen_row, cen_col, target_dim, ap_rad
):
return _zero_pad_and_compute_fft_impl(
np.array(im_tuple), cen_row, cen_col, target_dim, ap_rad
)


@functools.wraps(_zero_pad_and_compute_fft_impl)
def _zero_pad_and_compute_fft_cached(im, cen_row, cen_col, target_dim, ap_rad):
return _zero_pad_and_compute_fft_cached_impl(
tuple(tuple(ii) for ii in im),
float(cen_row), float(cen_col), int(target_dim), float(ap_rad)
)


_zero_pad_and_compute_fft_cached.cache_info \
= _zero_pad_and_compute_fft_cached_impl.cache_info
_zero_pad_and_compute_fft_cached.cache_clear \
= _zero_pad_and_compute_fft_cached_impl.cache_clear


def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5):
"""deconvolve the PSF from an image in place.

Expand All @@ -398,6 +423,7 @@ def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5):
return kim, kpsf_im, msk


@functools.lru_cache(maxsize=128)
def _ksigma_kernels(
dim,
kernel_size,
Expand Down Expand Up @@ -505,6 +531,7 @@ def _ksigma_kernels(
)


@functools.lru_cache(maxsize=128)
def _gauss_kernels(
dim,
kernel_size,
Expand Down
8 changes: 5 additions & 3 deletions ngmix/tests/_galsim_sims.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
import galsim


def _get_obs(rng, set_noise_image=False, noise=1.0e-6):
def _get_obs(rng, set_noise_image=False, noise=1.0e-6, psf_fwhm=0.9, n=None):

psf_noise = 1.0e-6

scale = 0.263

psf_fwhm = 0.9
gal_fwhm = 0.7

psf = galsim.Gaussian(fwhm=psf_fwhm)
Expand All @@ -18,7 +17,10 @@ def _get_obs(rng, set_noise_image=False, noise=1.0e-6):
obj = galsim.Convolve(psf, obj0)

psf_im = psf.drawImage(scale=scale).array
im = obj.drawImage(scale=scale).array
if n is not None:
im = obj.drawImage(scale=scale, nx=n, ny=n).array
else:
im = obj.drawImage(scale=scale).array

psf_im += rng.normal(scale=psf_noise, size=psf_im.shape)
im += rng.normal(scale=noise, size=im.shape)
Expand Down
41 changes: 41 additions & 0 deletions ngmix/tests/test_metacal_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import time
import numpy as np
import ngmix
import ngmix.metacal.metacal
from ._galsim_sims import _get_obs
from ..metacal.metacal import _cached_galsim_stuff


def test_metacal_cache():
# first warm up numba
rng = np.random.RandomState(seed=100)
obs = _get_obs(rng, noise=0.005, set_noise_image=True, psf_fwhm=0.8, n=300)
t0 = time.time()
ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"])
t0 = time.time() - t0
print("first time: %r seconds" % t0, flush=True)
print(_cached_galsim_stuff.cache_info(), flush=True)

# now cache it
rng = np.random.RandomState(seed=10)
obs = _get_obs(rng, noise=0.005, set_noise_image=True, n=300)
t1 = time.time()
ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"])
t1 = time.time() - t1
print("second time: %r seconds" % t1, flush=True)
print(_cached_galsim_stuff.cache_info(), flush=True)

# now use cache
rng = np.random.RandomState(seed=10)
obs = _get_obs(rng, noise=0.005, set_noise_image=True, n=300)
t2 = time.time()
ngmix.metacal.get_all_metacal(obs, rng=rng, types=["noshear"])
t2 = time.time() - t2
print("third time: %r seconds (< %r?)" % (t2, t1*0.7), flush=True)
print(_cached_galsim_stuff.cache_info(), flush=True)

# numba should be slower always but we do not care how much
assert t1 < t0

# we expect roughly 30% gains
assert t2 < t1*0.7
Loading