Skip to content

Commit

Permalink
Merge pull request #227 from esheldon/ppsf-perf
Browse files Browse the repository at this point in the history
PERF add O(N) phase shift computation
  • Loading branch information
beckermr authored Aug 24, 2022
2 parents 892fd4a + 8de09b7 commit ad8e8ac
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 9 deletions.
3 changes: 2 additions & 1 deletion CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
## unreleased
## v2.2.0

### new features

- Added function to regularize moments results.
- Added numba to pre-PSF moments to reduce runtime.


## v2.1.0
Expand Down
2 changes: 1 addition & 1 deletion ngmix/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.1.0' # noqa
__version__ = '2.2.0' # noqa
50 changes: 44 additions & 6 deletions ngmix/prepsfmom.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,22 @@ def _measure_moments_fft(
cen_phase = _compute_cen_phase_shift(drow, dcol, dim, msk=msk)
kim *= cen_phase

fkf = kernels["fkf"]
fkr = kernels["fkr"]
fkp = kernels["fkp"]
fkc = kernels["fkc"]

mom_norm = kernels["fk00"]

return _measure_moments_fft_numba(
kim, kpsf_im, dim, eff_pad_factor, fkf, fkr, fkp, fkc, mom_norm, tot_var,
)


@njit
def _measure_moments_fft_numba(
kim, kpsf_im, dim, eff_pad_factor, fkf, fkr, fkp, fkc, mom_norm, tot_var,
):
# build the flux, radial, plus and cross kernels / moments
# the inverse FFT in our convention has a factor of 1/n per dimension
# the sums below are inverse FFTs but only computing the values at the
Expand All @@ -286,12 +302,6 @@ def _measure_moments_fft(
df4 = df2 * df2

# we only sum where the kernel is nonzero
fkf = kernels["fkf"]
fkr = kernels["fkr"]
fkp = kernels["fkp"]
fkc = kernels["fkc"]

mom_norm = kernels["fk00"]
mf = np.sum((kim * fkf).real) * df2
mr = np.sum((kim * fkr).real) * df2
mp = np.sum((kim * fkp).real) * df2
Expand Down Expand Up @@ -389,6 +399,34 @@ def _zero_pad_image(im, target_dim):
def _compute_cen_phase_shift(cen_row, cen_col, dim, msk=None):
"""computes exp(i*2*pi*k*cen) for shifting the phases of FFTS.
If you feed the centroid of a profile, then this factor times the raw FFT
of that profile will result in an FFT centered at the profile.
"""
f = fft.fftfreq(dim)
pxy = _compute_cen_phase_shift_numba(f, cen_row, cen_col)

if msk is not None:
pxy = pxy[msk]

return pxy


@njit
def _compute_cen_phase_shift_numba(f, cen_row, cen_col):
# this reshaping makes sure the arrays broadcast nicely into a grid
fx = f.reshape(1, -1)
fy = f.reshape(-1, 1)
kcen_x = fx * (2.0 * np.pi * cen_col)
kcen_y = fy * (2.0 * np.pi * cen_row)
px = np.cos(kcen_x) + 1j*np.sin(kcen_x)
py = np.cos(kcen_y) + 1j*np.sin(kcen_y)
pxy = px * py
return pxy


def _compute_cen_phase_shift_orig(cen_row, cen_col, dim, msk=None):
"""computes exp(i*2*pi*k*cen) for shifting the phases of FFTS.
If you feed the centroid of a profile, then this factor times the raw FFT
of that profile will result in an FFT centered at the profile.
"""
Expand Down
2 changes: 1 addition & 1 deletion ngmix/tests/test_leastsqbound.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def test_leastsqbound_smoke(use_prior):
def test_leastsqbound_bounds(fracdev_bounds):
rng = np.random.RandomState(2830)

ntrial = 10
ntrial = 100
fit_model = 'bd'
scale = 0.263

Expand Down
21 changes: 21 additions & 0 deletions ngmix/tests/test_prepsfmom.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,34 @@
PrePSFMom,
_gauss_kernels,
_zero_pad_and_compute_fft_cached_impl,
_compute_cen_phase_shift,
_compute_cen_phase_shift_orig,
)
from ngmix import Jacobian
from ngmix import Observation
from ngmix.moments import make_mom_result
import ngmix.flags


@pytest.mark.parametrize("row", [-0.4, 0, 1.2, 4.5])
@pytest.mark.parametrize("col", [-0.32434, 0, 1.43232, 4.56775])
@pytest.mark.parametrize("dim,msk", [
(100, None),
(453, None),
(3, np.array([[True, False, True], [True, True, False], [True, True, True]])),
(4, np.array([
[False, True, False, True],
[False, True, True, False],
[True, True, True, True],
[False, False, True, False]])),
])
def test_cen_phase_shift(row, col, msk, dim):
np.testing.assert_allclose(
_compute_cen_phase_shift(row, col, dim, msk=msk),
_compute_cen_phase_shift_orig(row, col, dim, msk=msk)
)


def _report_info(s, arr, mn, err):
if mn is not None and err is not None:
print(
Expand Down

0 comments on commit ad8e8ac

Please sign in to comment.