diff --git a/ngmix/prepsfmom.py b/ngmix/prepsfmom.py index 4604e1f7..c015e9bc 100644 --- a/ngmix/prepsfmom.py +++ b/ngmix/prepsfmom.py @@ -276,6 +276,23 @@ def _measure_moments_fft( cen_phase = _compute_cen_phase_shift(drow, dcol, dim, msk=msk) kim *= cen_phase + # we only sum where the kernel is nonzero + fkf = kernels["fkf"] + fkr = kernels["fkr"] + fkp = kernels["fkp"] + fkc = kernels["fkc"] + + mom_norm = kernels["fk00"] + return _numba_bits( + kim, fkf, fkr, fkp, fkc, eff_pad_factor, kpsf_im, mom_norm, tot_var, dim, + ) + + +@njit +def _numba_bits( + kim, fkf, fkr, fkp, fkc, eff_pad_factor, kpsf_im, mom_norm, tot_var, dim, +): + # build the flux, radial, plus and cross kernels / moments # the inverse FFT in our convention has a factor of 1/n per dimension # the sums below are inverse FFTs but only computing the values at the @@ -285,13 +302,6 @@ def _measure_moments_fft( df2 = df * df df4 = df2 * df2 - # we only sum where the kernel is nonzero - fkf = kernels["fkf"] - fkr = kernels["fkr"] - fkp = kernels["fkp"] - fkc = kernels["fkc"] - - mom_norm = kernels["fk00"] mf = np.sum((kim * fkf).real) * df2 mr = np.sum((kim * fkr).real) * df2 mp = np.sum((kim * fkp).real) * df2 @@ -397,10 +407,16 @@ def _compute_cen_phase_shift(cen_row, cen_col, dim, msk=None): fx = f.reshape(1, -1) fy = f.reshape(-1, 1) kcen = fy*cen_row + fx*cen_col + if msk is not None: - return np.cos(kcen[msk]) + 1j*np.sin(kcen[msk]) - else: - return np.cos(kcen) + 1j*np.sin(kcen) + kcen = kcen[msk] + + return _comp_phase(kcen) + + +@njit +def _comp_phase(kcen): + return np.cos(kcen) + 1j*np.sin(kcen) def _zero_pad_and_compute_fft_impl(im, cen_row, cen_col, target_dim, ap_rad): diff --git a/ngmix/tests/test_prepsfmom.py b/ngmix/tests/test_prepsfmom.py index 7b4cdd77..6488b9a0 100644 --- a/ngmix/tests/test_prepsfmom.py +++ b/ngmix/tests/test_prepsfmom.py @@ -211,7 +211,7 @@ def test_prepsfmom_speed_and_cache(): assert _zero_pad_and_compute_fft_cached_impl.cache_info().misses == 4 # now we test with full caching - nfit = 1000 + nfit = 2000 dt = time.time() for _ in range(nfit): with obs.writeable():