Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF cache FFTs of various quantities #216

Merged
merged 8 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 22 additions & 11 deletions ngmix/metacal/metacal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import copy
import logging
from functools import lru_cache
import numpy as np
from ..gexceptions import GMixRangeError, BootPSFFailure
from ..shape import Shape
Expand All @@ -20,6 +21,14 @@
logger = logging.getLogger(__name__)


@lru_cache(maxsize=128)
def _cached_galsim_stuff(img, wcs_repr, xinterp):
    """
    Build a galsim image and its interpolated version, memoized on the inputs.

    Parameters
    ----------
    img: hashable nested sequence
        Pixel values.  Converted with np.array before being handed to
        galsim.Image.  Must be hashable because it is part of the cache key.
    wcs_repr: str
        String that is evaluated with eval() to produce the wcs passed to
        galsim.Image.
    xinterp: hashable
        Forwarded to galsim.InterpolatedImage as x_interpolant.

    Returns
    -------
    gs_image, gs_interp:
        The galsim.Image and the galsim.InterpolatedImage built from it.
        While an entry remains in the cache, calls with equal arguments
        return the very same objects rather than fresh copies.
    """
    import galsim

    pixels = np.array(img)
    # NOTE(review): eval() runs arbitrary code; this is only safe while
    # wcs_repr is produced by repr() of a trusted wcs object -- confirm that
    # every caller does so.
    wcs = eval(wcs_repr)

    gs_image = galsim.Image(pixels, wcs=wcs)
    gs_interp = galsim.InterpolatedImage(gs_image, x_interpolant=xinterp)
    return gs_image, gs_interp


class MetacalDilatePSF(object):
"""
Create manipulated images for use in metacalibration
Expand Down Expand Up @@ -355,22 +364,24 @@ def _set_data(self):
# these would share data with the original numpy arrays, make copies
# to be sure they don't get modified
#
self.image = galsim.Image(obs.image.copy(),
wcs=self.get_wcs())

self.psf_image = galsim.Image(obs.psf.image.copy(),
wcs=self.get_psf_wcs())
image, image_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.image.copy()),
repr(self.get_wcs()),
self.interp,
)
self.image = image
self.image_int = image_int

# interpolated psf image
psf_int = galsim.InterpolatedImage(self.psf_image,
x_interpolant=self.interp)
psf_image, psf_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.psf.image.copy()),
repr(self.get_psf_wcs()),
self.interp,
)
self.psf_image = psf_image

# this can be used to deconvolve the psf from the galaxy image
psf_int_inv = galsim.Deconvolve(psf_int)

self.image_int = galsim.InterpolatedImage(self.image,
x_interpolant=self.interp)

# deconvolved galaxy image, psf+pixel removed
self.image_int_nopsf = galsim.Convolve(self.image_int,
psf_int_inv)
Expand Down
49 changes: 38 additions & 11 deletions ngmix/prepsfmom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import functools

import numpy as np
import scipy.fft as fft
Expand Down Expand Up @@ -84,14 +85,14 @@ def _meas(self, obs, psf_obs, return_kernels):
eff_pad_factor = target_dim / obs.image.shape[0]

# pad image, psf and weight map, get FFTs, apply cen_phases
kim, im_row, im_col = _zero_pad_and_compute_fft(
kim, im_row, im_col = _zero_pad_and_compute_fft_cached(
obs.image, obs.jacobian.row0, obs.jacobian.col0, target_dim,
self.ap_rad,
)
fft_dim = kim.shape[0]

if psf_obs is not None:
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft(
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft_cached(
psf_obs.image,
psf_obs.jacobian.row0, psf_obs.jacobian.col0,
target_dim,
Expand Down Expand Up @@ -128,17 +129,17 @@ def _meas(self, obs, psf_obs, return_kernels):
# now build the kernels
if self.kernel == "ksigma":
kernels = _ksigma_kernels(
target_dim,
self.fwhm,
obs.jacobian.dvdrow, obs.jacobian.dvdcol,
obs.jacobian.dudrow, obs.jacobian.dudcol,
int(target_dim),
float(self.fwhm),
float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol),
float(obs.jacobian.dudrow), float(obs.jacobian.dudcol),
)
elif self.kernel in ["gauss", "pgauss"]:
kernels = _gauss_kernels(
target_dim,
self.fwhm,
obs.jacobian.dvdrow, obs.jacobian.dvdcol,
obs.jacobian.dudrow, obs.jacobian.dudcol,
int(target_dim),
float(self.fwhm),
float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol),
float(obs.jacobian.dudrow), float(obs.jacobian.dudcol),
)
else:
raise ValueError(
Expand Down Expand Up @@ -365,7 +366,7 @@ def _compute_cen_phase_shift(cen_row, cen_col, dim, msk=None):
return np.cos(kcen) + 1j*np.sin(kcen)


def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad):
def _zero_pad_and_compute_fft_impl(im, cen_row, cen_col, target_dim, ap_rad):
"""zero pad and compute the FFT

Returns the fft, cen_row in the padded image, and cen_col in the padded image.
Expand All @@ -382,6 +383,30 @@ def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad):
return kpim, pad_cen_row, pad_cen_col


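# lru_cache needs hashable arguments, so the image is passed in as a tuple of
# tuples and turned back into an array before the real computation;
# see https://stackoverflow.com/a/52332109 for how this works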
@functools.lru_cache(maxsize=128)
def _zero_pad_and_compute_fft_cached_impl(
    im_tuple, cen_row, cen_col, target_dim, ap_rad
):
    """Memoized front end for `_zero_pad_and_compute_fft_impl`.

    Parameters
    ----------
    im_tuple : hashable nested sequence
        The image in a hashable form; it is converted with np.array before
        being passed on.
    cen_row, cen_col, target_dim, ap_rad
        Forwarded unchanged to `_zero_pad_and_compute_fft_impl`.  All of them
        are part of the cache key and therefore must be hashable.

    Returns
    -------
    Whatever `_zero_pad_and_compute_fft_impl` returns for these inputs.  While
    an entry remains cached, repeated calls return the same object.
    """
    im = np.array(im_tuple)
    return _zero_pad_and_compute_fft_impl(
        im, cen_row, cen_col, target_dim, ap_rad
    )


@functools.wraps(_zero_pad_and_compute_fft_impl)
def _zero_pad_and_compute_fft_cached(im, cen_row, cen_col, target_dim, ap_rad):
    # No docstring on purpose: functools.wraps copies __doc__ (and __name__)
    # from _zero_pad_and_compute_fft_impl onto this wrapper.
    #
    # Everything is normalized to hashable builtin types before hitting the
    # lru_cache-decorated implementation: the image becomes a tuple of row
    # tuples and the scalars are coerced to float / int.
    hashable_im = tuple(tuple(row) for row in im)
    row0 = float(cen_row)
    col0 = float(cen_col)
    dim = int(target_dim)
    rad = float(ap_rad)
    return _zero_pad_and_compute_fft_cached_impl(
        hashable_im, row0, col0, dim, rad
    )


# expose the lru_cache introspection helpers of the cached implementation on
# the public wrapper so callers can inspect or reset the cache through it
_zero_pad_and_compute_fft_cached.cache_info = (
    _zero_pad_and_compute_fft_cached_impl.cache_info
)
_zero_pad_and_compute_fft_cached.cache_clear = (
    _zero_pad_and_compute_fft_cached_impl.cache_clear
)


def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5):
"""deconvolve the PSF from an image in place.

Expand All @@ -398,6 +423,7 @@ def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5):
return kim, kpsf_im, msk


@functools.lru_cache(maxsize=128)
def _ksigma_kernels(
dim,
kernel_size,
Expand Down Expand Up @@ -505,6 +531,7 @@ def _ksigma_kernels(
)


@functools.lru_cache(maxsize=128)
def _gauss_kernels(
dim,
kernel_size,
Expand Down
133 changes: 133 additions & 0 deletions ngmix/tests/test_prepsfmom.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import galsim
import numpy as np
import pytest
import time
from flaky import flaky

from ngmix.prepsfmom import (
KSigmaMom, PGaussMom,
_build_square_apodization_mask,
PrePSFMom,
_gauss_kernels,
_zero_pad_and_compute_fft_cached_impl,
)
from ngmix import Jacobian
from ngmix import Observation
Expand Down Expand Up @@ -94,6 +98,135 @@ def test_prepsfmom_raises_badjacob(cls):
assert "same WCS Jacobia" in str(e.value)


@flaky
def test_prepsfmom_speed_and_cache():
    """Check that the kernel and FFT caches are used and make refits cheaper.

    Three stages are timed with the monotonic, high-resolution
    time.perf_counter clock:

    1. a first fit with perturbed inputs (its time is only printed),
    2. a second fit with different inputs so that both caches miss again,
    3. ``nfit`` fits where only the image changes, so the kernel cache hits
       every time while the FFT cache misses once per fit.

    The cache-miss counters are asserted after each stage and the mean time
    per fit of stage 3 must be below the time of stage 2.  Timing is noisy,
    hence the ``@flaky`` decorator.
    """
    image_size = 48
    psf_image_size = 53
    pixel_scale = 0.263
    fwhm = 0.9
    psf_fwhm = 0.9
    snr = 20
    mom_fwhm = 2

    rng = np.random.RandomState(seed=100)

    cen = (image_size - 1)/2
    psf_cen = (psf_image_size - 1)/2
    gs_wcs = galsim.ShearWCS(
        pixel_scale, galsim.Shear(g1=-0.1, g2=0.06)).jacobian()
    scale = np.sqrt(gs_wcs.pixelArea())
    shift = rng.uniform(low=-scale/2, high=scale/2, size=2)
    psf_shift = rng.uniform(low=-scale/2, high=scale/2, size=2)
    xy = gs_wcs.toImage(galsim.PositionD(shift))
    psf_xy = gs_wcs.toImage(galsim.PositionD(psf_shift))

    jac = Jacobian(
        y=cen + xy.y, x=cen + xy.x,
        dudx=gs_wcs.dudx, dudy=gs_wcs.dudy,
        dvdx=gs_wcs.dvdx, dvdy=gs_wcs.dvdy)

    psf_jac = Jacobian(
        y=psf_cen + psf_xy.y, x=psf_cen + psf_xy.x,
        dudx=gs_wcs.dudx, dudy=gs_wcs.dudy,
        dvdx=gs_wcs.dvdx, dvdy=gs_wcs.dvdy)

    gal = galsim.Gaussian(
        fwhm=fwhm
    ).shear(
        g1=-0.1, g2=0.2
    ).withFlux(
        400
    ).shift(
        dx=shift[0], dy=shift[1]
    )
    psf = galsim.Gaussian(
        fwhm=psf_fwhm
    ).shear(
        g1=0.3, g2=-0.15
    )
    im = galsim.Convolve([gal, psf]).drawImage(
        nx=image_size,
        ny=image_size,
        wcs=gs_wcs
    ).array
    noise = np.sqrt(np.sum(im**2)) / snr
    wgt = np.ones_like(im) / noise**2

    psf_im = psf.shift(
        dx=psf_shift[0], dy=psf_shift[1]
    ).drawImage(
        nx=psf_image_size,
        ny=psf_image_size,
        wcs=gs_wcs
    ).array

    # now we test the speed + caching; start from empty caches so the miss
    # counters below are absolute
    _gauss_kernels.cache_clear()
    _zero_pad_and_compute_fft_cached_impl.cache_clear()

    # the first fit will do numba stuff, so we exclude it
    # we also perturb the various inputs to fool our caches
    fitter = PGaussMom(
        fwhm=mom_fwhm + 1e-3,
    )

    obs = Observation(
        image=im + 1e-6,
        weight=wgt,
        jacobian=jac,
        psf=Observation(image=psf_im + 1e-8, jacobian=psf_jac),
    )

    # perf_counter is monotonic and high resolution, unlike time.time, which
    # matters because single fits can take well under a millisecond
    t0 = time.perf_counter()
    fitter.go(obs=obs)
    dt1 = time.perf_counter() - t0
    print("\n%0.4f ms for first fit" % (dt1*1000))

    # we miss once here for kernels, twice for FFTs (image and psf)
    assert _gauss_kernels.cache_info().misses == 1
    assert _zero_pad_and_compute_fft_cached_impl.cache_info().misses == 2

    # the second fit will have numba cached, but not the other kernel and FFT
    # caches since fwhm, image and psf image all differ from the first fit
    fitter = PGaussMom(
        fwhm=mom_fwhm,
    )

    obs = Observation(
        image=im,
        weight=wgt,
        jacobian=jac,
        psf=Observation(image=psf_im, jacobian=psf_jac),
    )

    t0 = time.perf_counter()
    fitter.go(obs=obs)
    dt2 = time.perf_counter() - t0
    print("%0.4f ms for second fit" % (dt2*1000))

    # running totals: 2 kernel misses (one per fwhm) and 4 FFT misses
    # (image + psf for each of the two fits)
    assert _gauss_kernels.cache_info().misses == 2
    assert _zero_pad_and_compute_fft_cached_impl.cache_info().misses == 4

    # now we test with the kernel and psf FFT cached; the image is nudged on
    # every iteration so only its FFT has to be recomputed
    nfit = 1000
    t0 = time.perf_counter()
    for _ in range(nfit):
        with obs.writeable():
            obs.image += 1e-6
        fitter.go(obs=obs)
    dt3 = time.perf_counter() - t0

    print("%0.4f ms per fit" % (dt3/nfit*1000))

    # the kernels never miss again; the FFT cache gains exactly one miss per
    # fit, from the modified image
    assert _gauss_kernels.cache_info().misses == 2
    assert _zero_pad_and_compute_fft_cached_impl.cache_info().misses == 4 + nfit

    # if numba stuff is cached this does not work so commented out
    # assert dt2 < dt1
    assert dt3/nfit < dt2
beckermr marked this conversation as resolved.
Show resolved Hide resolved
beckermr marked this conversation as resolved.
Show resolved Hide resolved


def _stack_list_of_dicts(res):
def _get_dtype(v):
if isinstance(v, float):
Expand Down