Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH add extra conv and deconv ops to prepsf moments #215

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions ngmix/metacal/metacal.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"""
import copy
import logging
from functools import lru_cache

import numpy as np
from ..gexceptions import GMixRangeError, BootPSFFailure
from ..shape import Shape
Expand All @@ -20,6 +22,14 @@
logger = logging.getLogger(__name__)


@lru_cache(maxsize=128)
def _cached_galsim_stuff(img, wcs_repr, xinterp):
import galsim
image = galsim.Image(np.array(img), wcs=eval(wcs_repr))
image_int = galsim.InterpolatedImage(image, x_interpolant=xinterp)
return image, image_int


class MetacalDilatePSF(object):
"""
Create manipulated images for use in metacalibration
Expand Down Expand Up @@ -355,22 +365,24 @@ def _set_data(self):
# these would share data with the original numpy arrays, make copies
# to be sure they don't get modified
#
self.image = galsim.Image(obs.image.copy(),
wcs=self.get_wcs())

self.psf_image = galsim.Image(obs.psf.image.copy(),
wcs=self.get_psf_wcs())
image, image_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.image.copy()),
repr(self.get_wcs()),
self.interp,
)
self.image = image
self.image_int = image_int

# interpolated psf image
psf_int = galsim.InterpolatedImage(self.psf_image,
x_interpolant=self.interp)
psf_image, psf_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.psf.image.copy()),
repr(self.get_psf_wcs()),
self.interp,
)
self.psf_image = psf_image

# this can be used to deconvolve the psf from the galaxy image
psf_int_inv = galsim.Deconvolve(psf_int)

self.image_int = galsim.InterpolatedImage(self.image,
x_interpolant=self.interp)

# deconvolved galaxy image, psf+pixel removed
self.image_int_nopsf = galsim.Convolve(self.image_int,
psf_int_inv)
Expand Down
230 changes: 170 additions & 60 deletions ngmix/prepsfmom.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import functools

import numpy as np
import scipy.fft as fft
Expand Down Expand Up @@ -51,7 +52,10 @@ def __init__(self, fwhm, kernel, pad_factor=4, ap_rad=1.5):
"The kernel '%s' for PrePSFMom is not recognized!" % self.kernel
)

def go(self, obs, return_kernels=False, no_psf=False):
def go(
self, obs, return_kernels=False, no_psf=False,
extra_conv_psfs=None, extra_deconv_psfs=None,
):
"""Measure the pre-PSF ksigma moments.

Parameters
Expand All @@ -64,23 +68,34 @@ def go(self, obs, return_kernels=False, no_psf=False):
no_psf : bool, optional
If True, allow inputs without a PSF observation. Defaults to False
so that any input observation without a PSF will raise an error.
extra_conv_psfs : list of ngmix obs, optional
If specified, these PSFs will be convolved into the image before the moments
are measured.
extra_deconv_psfs : list of ngmix obs, optional
If specified, these PSFs will be deconvolved from the image before the
moments are measured.

Returns
-------
result dictionary
"""
psf_obs = _check_obs_and_get_psf_obs(obs, no_psf)
return self._meas(obs, psf_obs, return_kernels)

def _meas(self, obs, psf_obs, return_kernels):
# pick the larger size
if psf_obs is not None:
if obs.image.shape[0] > psf_obs.image.shape[0]:
target_dim = int(obs.image.shape[0] * self.pad_factor)
else:
target_dim = int(psf_obs.image.shape[0] * self.pad_factor)
else:
target_dim = int(obs.image.shape[0] * self.pad_factor)
conv_psf_obs_list, deconv_psf_obs_list = _check_obs_and_get_psf_obs(
obs, no_psf,
extra_conv_psfs=extra_conv_psfs,
extra_deconv_psfs=extra_deconv_psfs,
)
return self._meas(obs, conv_psf_obs_list, deconv_psf_obs_list, return_kernels)

def _meas(self, obs, conv_psf_obs_list, deconv_psf_obs_list, return_kernels):
# pick the largest size
target_dim = max(
max([o.image.shape[0] for o in conv_psf_obs_list])
if conv_psf_obs_list else -1,
max([o.image.shape[0] for o in deconv_psf_obs_list])
if deconv_psf_obs_list else -1,
obs.image.shape[0],
)
target_dim = int(target_dim * self.pad_factor)
eff_pad_factor = target_dim / obs.image.shape[0]

# pad image, psf and weight map, get FFTs, apply cen_phases
Expand All @@ -90,18 +105,17 @@ def _meas(self, obs, psf_obs, return_kernels):
)
fft_dim = kim.shape[0]

if psf_obs is not None:
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft(
psf_obs.image,
psf_obs.jacobian.row0, psf_obs.jacobian.col0,
target_dim,
0, # we do not apodize PSF stamps since it should not be needed
)
else:
# delta function in k-space
kpsf_im = np.ones_like(kim, dtype=np.complex128)
psf_im_row = 0.0
psf_im_col = 0.0
(
deconv_kpsf_im,
deconv_psf_im_row,
deconv_psf_im_col,
) = _zero_pad_and_compute_fft_cached_list(deconv_psf_obs_list, kim, target_dim)

(
conv_kpsf_im,
conv_psf_im_row,
conv_psf_im_col,
) = _zero_pad_and_compute_fft_cached_list(conv_psf_obs_list, kim, target_dim)

# the final, deconvolved image we want is
#
Expand All @@ -128,17 +142,17 @@ def _meas(self, obs, psf_obs, return_kernels):
# now build the kernels
if self.kernel == "ksigma":
kernels = _ksigma_kernels(
target_dim,
self.fwhm,
obs.jacobian.dvdrow, obs.jacobian.dvdcol,
obs.jacobian.dudrow, obs.jacobian.dudcol,
int(target_dim),
float(self.fwhm),
float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol),
float(obs.jacobian.dudrow), float(obs.jacobian.dudcol),
)
elif self.kernel in ["gauss", "pgauss"]:
kernels = _gauss_kernels(
target_dim,
self.fwhm,
obs.jacobian.dvdrow, obs.jacobian.dvdcol,
obs.jacobian.dudrow, obs.jacobian.dudcol,
int(target_dim),
float(self.fwhm),
float(obs.jacobian.dvdrow), float(obs.jacobian.dvdcol),
float(obs.jacobian.dudrow), float(obs.jacobian.dudcol),
)
else:
raise ValueError(
Expand All @@ -151,8 +165,11 @@ def _meas(self, obs, psf_obs, return_kernels):

# run the actual measurements and return
mom, mom_cov = _measure_moments_fft(
kim, kpsf_im, tot_var, eff_pad_factor, kernels,
im_row - psf_im_row, im_col - psf_im_col,
kim,
deconv_kpsf_im, conv_kpsf_im,
tot_var, eff_pad_factor, kernels,
im_row - deconv_psf_im_row + conv_psf_im_row,
im_col - deconv_psf_im_col + conv_psf_im_col,
)
res = make_mom_result(mom, mom_cov)
if res['flags'] != 0:
Expand Down Expand Up @@ -225,19 +242,27 @@ def __init__(self, fwhm, pad_factor=4, ap_rad=1.5):
PrePSFGaussMom = PGaussMom


def _measure_moments_fft(kim, kpsf_im, tot_var, eff_pad_factor, kernels, drow, dcol):
def _measure_moments_fft(
kim, deconv_kpsf_im, conv_kpsf_im, tot_var, eff_pad_factor, kernels,
drow, dcol
):
# we only need to do things where the kernel is non-zero
# this saves a bunch of CPU cycles
msk = kernels["msk"]
dim = kim.shape[0]

# deconvolve PSF
kim, kpsf_im, _ = _deconvolve_im_psf_inplace(
kim[msk],
kpsf_im[msk],
# max amplitude is flux which is 0,0 in the standard FFT convention
np.abs(kpsf_im[0, 0]),
)
if deconv_kpsf_im is not None or conv_kpsf_im is not None:
kim, inv_kpsf_im, _ = _deconvolve_im_psf_inplace(
kim[msk],
deconv_kpsf_im[msk] if deconv_kpsf_im is not None else None,
# max amplitude is flux which is 0,0 in the standard FFT convention
np.abs(deconv_kpsf_im[0, 0]) if deconv_kpsf_im is not None else None,
conv_kpsf_im[msk] if conv_kpsf_im is not None else None,
)
else:
kim = kim[msk]
inv_kpsf_im = np.ones_like(kim, dtype=np.complex128)

# put in phase shift as described above
# the sin and cos are expensive so we only compute them where we will
Expand Down Expand Up @@ -278,7 +303,7 @@ def _measure_moments_fft(kim, kpsf_im, tot_var, eff_pad_factor, kernels, drow, d
m_cov[1, 1] = 1
tot_var *= eff_pad_factor**2
tot_var_df4 = tot_var * df4
kerns = [fkp / kpsf_im, fkc / kpsf_im, fkr / kpsf_im, fkf / kpsf_im]
kerns = [fkp * inv_kpsf_im, fkc * inv_kpsf_im, fkr * inv_kpsf_im, fkf * inv_kpsf_im]
conj_kerns = [np.conj(k) for k in kerns]
for i in range(2, 6):
for j in range(i, 6):
Expand Down Expand Up @@ -382,22 +407,92 @@ def _zero_pad_and_compute_fft(im, cen_row, cen_col, target_dim, ap_rad):
return kpim, pad_cen_row, pad_cen_col


def _deconvolve_im_psf_inplace(kim, kpsf_im, max_amp, min_psf_frac=1e-5):
# see https://stackoverflow.com/a/52332109 for how this works
@functools.lru_cache(maxsize=128)
def _zero_pad_and_compute_fft_cached_impl(
im_tuple, cen_row, cen_col, target_dim, ap_rad
):
return _zero_pad_and_compute_fft(
np.array(im_tuple), cen_row, cen_col, target_dim, ap_rad
)


@functools.wraps(_zero_pad_and_compute_fft)
def _zero_pad_and_compute_fft_cached(im, cen_row, cen_col, target_dim, ap_rad):
return _zero_pad_and_compute_fft_cached_impl(
tuple(tuple(ii) for ii in im),
float(cen_row), float(cen_col), int(target_dim), float(ap_rad)
)


_zero_pad_and_compute_fft_cached.cache_info \
= _zero_pad_and_compute_fft_cached_impl.cache_info
_zero_pad_and_compute_fft_cached.cache_clear \
= _zero_pad_and_compute_fft_cached_impl.cache_clear


def _zero_pad_and_compute_fft_cached_list(psf_obs_list, kim, target_dim):
if len(psf_obs_list) == 0:
# delta function in k-space
kpsf_im = None
psf_im_row = 0.0
psf_im_col = 0.0
elif len(psf_obs_list) == 1:
return _zero_pad_and_compute_fft_cached(
psf_obs_list[0].image,
psf_obs_list[0].jacobian.row0, psf_obs_list[0].jacobian.col0,
target_dim,
0, # we do not apodize PSF stamps since it should not be needed
)
else:
fft_res = [
_zero_pad_and_compute_fft_cached(
psf_obs.image,
psf_obs.jacobian.row0, psf_obs.jacobian.col0,
target_dim,
0, # we do not apodize PSF stamps since it should not be needed
)
for psf_obs in psf_obs_list
]
kpsf_im = np.prod([f[0] for f in fft_res], axis=0)
psf_im_row = sum([f[1] for f in fft_res])
psf_im_col = sum([f[2] for f in fft_res])

return kpsf_im, psf_im_row, psf_im_col


def _deconvolve_im_psf_inplace(
kim, deconv_kpsf_im, max_amp, conv_kpsf_im, min_psf_frac=1e-5
):
"""deconvolve the PSF from an image in place.

Returns the deconvolved image, the kpsf_im used,
and a bool mask marking PSF modes that were truncated
"""
min_amp = min_psf_frac * max_amp
abs_kpsf_im = np.abs(kpsf_im)
msk = abs_kpsf_im <= min_amp
if np.any(msk):
kpsf_im[msk] = kpsf_im[msk] / abs_kpsf_im[msk] * min_amp
if deconv_kpsf_im is not None:
abs_kpsf_im = np.abs(deconv_kpsf_im)
msk = abs_kpsf_im <= min_amp
if np.any(msk):
deconv_kpsf_im[msk] = deconv_kpsf_im[msk] / abs_kpsf_im[msk] * min_amp

if conv_kpsf_im is not None:
inv_kpsf_im = conv_kpsf_im / deconv_kpsf_im
else:
inv_kpsf_im = 1.0 / deconv_kpsf_im
else:
msk = np.ones_like(kim, dtype=bool)

if conv_kpsf_im is not None:
inv_kpsf_im = conv_kpsf_im
else:
inv_kpsf_im = np.ones_like(kim, dtype=np.complex128)

kim /= kpsf_im
return kim, kpsf_im, msk
kim *= inv_kpsf_im
return kim, inv_kpsf_im, msk


@functools.lru_cache(maxsize=128)
def _ksigma_kernels(
dim,
kernel_size,
Expand Down Expand Up @@ -505,6 +600,7 @@ def _ksigma_kernels(
)


@functools.lru_cache(maxsize=128)
def _gauss_kernels(
dim,
kernel_size,
Expand Down Expand Up @@ -596,7 +692,16 @@ def _gauss_kernels(
)


def _check_obs_and_get_psf_obs(obs, no_psf):
def _jacobian_close(jac1, jac2):
return np.allclose(
[jac1.dudcol, jac1.dudrow, jac1.dvdcol, jac1.dvdrow],
[jac2.dudcol, jac2.dudrow, jac2.dvdcol, jac2.dvdrow]
)


def _check_obs_and_get_psf_obs(
obs, no_psf, extra_conv_psfs=None, extra_deconv_psfs=None,
):
if not isinstance(obs, Observation):
raise ValueError("input obs must be an Observation")

Expand All @@ -608,14 +713,19 @@ def _check_obs_and_get_psf_obs(obs, no_psf):
raise RuntimeError("The PSF must be set to measure a pre-PSF moment!")

if not no_psf:
psf_obs = obs.get_psf()

if psf_obs.jacobian.get_galsim_wcs() != obs.jacobian.get_galsim_wcs():
raise RuntimeError(
"The PSF and observation must have the same WCS "
"Jacobian for measuring pre-PSF moments."
)
conv_psfs = extra_conv_psfs or []
deconv_psfs = [obs.get_psf()] + (extra_deconv_psfs or [])
else:
psf_obs = None
conv_psfs = extra_conv_psfs or []
deconv_psfs = extra_deconv_psfs or []

if any(
not _jacobian_close(psf_obs.jacobian, obs.jacobian)
for psf_obs in conv_psfs + deconv_psfs
):
raise RuntimeError(
"The PSF and observation must have the same WCS "
"Jacobian for measuring pre-PSF moments."
)

return psf_obs
return conv_psfs, deconv_psfs
Loading