Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH make caching optional #231

Merged
merged 6 commits into from
Nov 9, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
name: tests
strategy:
matrix:
pyver: [3.7, 3.8, 3.9]
pyver: ["3.8", "3.9", "3.10"]

runs-on: "ubuntu-latest"

Expand Down
7 changes: 7 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
## v2.3.0

### new features

- Caching in pre-psf moments and metacal is now optional
with an API to turn it on. Default is off.

## v2.2.1

### New Features
Expand Down
2 changes: 1 addition & 1 deletion ngmix/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.2.1' # noqa
__version__ = '2.3.0' # noqa
48 changes: 39 additions & 9 deletions ngmix/metacal/metacal.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,44 @@
logger = logging.getLogger(__name__)


@lru_cache(maxsize=128)
def _cached_galsim_stuff(img, wcs_repr, xinterp):
USE_GALSIM_CACHE = False


def turn_on_galsim_caching():
global USE_GALSIM_CACHE
USE_GALSIM_CACHE = True


def turn_off_galsim_caching():
global USE_GALSIM_CACHE
USE_GALSIM_CACHE = False
_cached_galsim_stuff.cache_clear()


def _galsim_stuff(img, wcs, xinterp):
if USE_GALSIM_CACHE:
return _cached_galsim_stuff(
tuple(tuple(ii) for ii in img),
repr(wcs),
xinterp,
)
else:
return _galsim_stuff_impl(img, wcs, xinterp)


def _galsim_stuff_impl(img, wcs, xinterp):
import galsim
image = galsim.Image(np.array(img), wcs=eval(wcs_repr))
image = galsim.Image(img, wcs=wcs)
image_int = galsim.InterpolatedImage(image, x_interpolant=xinterp)
return image, image_int


@lru_cache(maxsize=128)
def _cached_galsim_stuff(img, wcs_repr, xinterp):
import galsim # noqa
return _galsim_stuff_impl(np.array(img), eval(wcs_repr), xinterp)


class MetacalDilatePSF(object):
"""
Create manipulated images for use in metacalibration
Expand Down Expand Up @@ -364,17 +394,17 @@ def _set_data(self):
# these would share data with the original numpy arrays, make copies
# to be sure they don't get modified
#
image, image_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.image.copy()),
repr(self.get_wcs()),
image, image_int = _galsim_stuff(
obs.image.copy(),
self.get_wcs(),
self.interp,
)
self.image = image
self.image_int = image_int

psf_image, psf_int = _cached_galsim_stuff(
tuple(tuple(ii) for ii in obs.psf.image.copy()),
repr(self.get_psf_wcs()),
psf_image, psf_int = _galsim_stuff(
obs.psf.image.copy(),
self.get_psf_wcs(),
self.interp,
)
self.psf_image = psf_image
Expand Down
125 changes: 114 additions & 11 deletions ngmix/prepsfmom.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,32 @@

logger = logging.getLogger(__name__)

USE_FFT_CACHE = False
USE_KERNEL_CACHE = False


def turn_on_fft_caching():
global USE_FFT_CACHE
USE_FFT_CACHE = True


def turn_off_fft_caching():
global USE_FFT_CACHE
USE_FFT_CACHE = False
_zero_pad_and_compute_fft_cached_impl.cache_clear()


def turn_on_kernel_caching():
global USE_KERNEL_CACHE
USE_KERNEL_CACHE = True


def turn_off_kernel_caching():
global USE_KERNEL_CACHE
USE_KERNEL_CACHE = False
_gauss_kernels_cached.cache_clear()
_ksigma_kernels_cached.cache_clear()


class PrePSFMom(object):
"""Measure pre-PSF weighted real-space moments.
Expand Down Expand Up @@ -92,14 +118,14 @@ def _meas(self, obs, psf_obs, return_kernels):
eff_pad_factor = target_dim / obs.image.shape[0]

# pad image, psf and weight map, get FFTs, apply cen_phases
kim, im_row, im_col = _zero_pad_and_compute_fft_cached(
kim, im_row, im_col = _zero_pad_and_compute_fft_maybe_cached(
obs.image, obs.jacobian.row0, obs.jacobian.col0, target_dim,
self.ap_rad,
)
fft_dim = kim.shape[0]

if psf_obs is not None:
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft_cached(
kpsf_im, psf_im_row, psf_im_col = _zero_pad_and_compute_fft_maybe_cached(
psf_obs.image,
psf_obs.jacobian.row0, psf_obs.jacobian.col0,
target_dim,
Expand Down Expand Up @@ -469,16 +495,21 @@ def _zero_pad_and_compute_fft_cached_impl(


@functools.wraps(_zero_pad_and_compute_fft_impl)
def _zero_pad_and_compute_fft_cached(im, cen_row, cen_col, target_dim, ap_rad):
return _zero_pad_and_compute_fft_cached_impl(
tuple(tuple(ii) for ii in im),
float(cen_row), float(cen_col), int(target_dim), float(ap_rad)
)
def _zero_pad_and_compute_fft_maybe_cached(im, cen_row, cen_col, target_dim, ap_rad):
if USE_FFT_CACHE:
return _zero_pad_and_compute_fft_cached_impl(
tuple(tuple(ii) for ii in im),
float(cen_row), float(cen_col), int(target_dim), float(ap_rad)
)
else:
return _zero_pad_and_compute_fft_impl(
im, cen_row, cen_col, target_dim, ap_rad,
)


_zero_pad_and_compute_fft_cached.cache_info \
_zero_pad_and_compute_fft_maybe_cached.cache_info \
= _zero_pad_and_compute_fft_cached_impl.cache_info
_zero_pad_and_compute_fft_cached.cache_clear \
_zero_pad_and_compute_fft_maybe_cached.cache_clear \
= _zero_pad_and_compute_fft_cached_impl.cache_clear


Expand Down Expand Up @@ -507,12 +538,48 @@ def _get_fwhm_smooth_profile(fwhm_smooth, fmag2):
return exp_val_smooth


@functools.lru_cache(maxsize=128)
def _ksigma_kernels(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
):
if USE_KERNEL_CACHE:
return _ksigma_kernels_cached(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
)
else:
return _ksigma_kernels_impl(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
)


@functools.lru_cache(maxsize=128)
def _ksigma_kernels_cached(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
):
return _ksigma_kernels_impl(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
)


def _ksigma_kernels_impl(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
):
"""This function builds a ksigma kernel in Fourier-space.

Expand Down Expand Up @@ -624,12 +691,48 @@ def _ksigma_kernels(
)


@functools.lru_cache(maxsize=128)
def _gauss_kernels(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
):
if USE_KERNEL_CACHE:
return _gauss_kernels_cached(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
)
else:
return _gauss_kernels_impl(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
)


@functools.lru_cache(maxsize=128)
def _gauss_kernels_cached(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
):
return _gauss_kernels_impl(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
)


def _gauss_kernels_impl(
dim,
kernel_size,
dvdrow, dvdcol, dudrow, dudcol,
fwhm_smooth,
):
"""This function builds a Gaussian kernel in Fourier-space.

Expand Down
31 changes: 31 additions & 0 deletions ngmix/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import ngmix.prepsfmom
import ngmix.metacal.metacal

import pytest


@pytest.fixture(
scope="module",
params=[(False, False), (True, False), (False, True), (True, True)],
)
def prepsfmom_caching(request):
if request.param[0]:
ngmix.prepsfmom.turn_on_fft_caching()
else:
ngmix.prepsfmom.turn_off_fft_caching()

if request.param[1]:
ngmix.prepsfmom.turn_on_kernel_caching()
else:
ngmix.prepsfmom.turn_off_kernel_caching()


@pytest.fixture(
scope="module",
params=[False, True],
)
def metacal_caching(request):
if request.param:
ngmix.metacal.metacal.turn_on_galsim_caching()
else:
ngmix.metacal.metacal.turn_off_galsim_caching()
16 changes: 8 additions & 8 deletions ngmix/tests/test_metacal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


@pytest.mark.parametrize('psf', ['gauss', 'fitgauss', 'galsim_obj', 'dilate'])
def test_metacal_smoke(psf):
def test_metacal_smoke(psf, metacal_caching):
rng = np.random.RandomState(seed=100)

obs = _get_obs(rng, noise=0.005)
Expand Down Expand Up @@ -44,7 +44,7 @@ def test_metacal_smoke(psf):

@pytest.mark.parametrize('psf', ['gauss', 'fitgauss', 'galsim_obj'])
@pytest.mark.parametrize('send_rng', [True, False])
def test_metacal_send_rng(psf, send_rng):
def test_metacal_send_rng(psf, send_rng, metacal_caching):

rng = np.random.RandomState(seed=100)
obs = _get_obs(rng, noise=0.005)
Expand Down Expand Up @@ -77,7 +77,7 @@ def test_metacal_send_rng(psf, send_rng):


@pytest.mark.parametrize('psf', ['gauss', 'fitgauss', 'galsim_obj', 'dilate'])
def test_metacal_types_smoke(psf):
def test_metacal_types_smoke(psf, metacal_caching):
rng = np.random.RandomState(seed=100)

obs = _get_obs(rng, noise=0.005)
Expand All @@ -104,7 +104,7 @@ def test_metacal_types_smoke(psf):

@pytest.mark.parametrize('otype', ['obs', 'obslist', 'mbobs'])
@pytest.mark.parametrize('set_noise_image', [True, False])
def test_metacal_fixnoise_smoke(otype, set_noise_image):
def test_metacal_fixnoise_smoke(otype, set_noise_image, metacal_caching):
rng = np.random.RandomState(seed=100)

obs = _get_obs(rng, noise=0.005, set_noise_image=set_noise_image)
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_metacal_fixnoise_smoke(otype, set_noise_image):


@pytest.mark.parametrize('fixnoise', [True, False])
def test_metacal_fixnoise(fixnoise):
def test_metacal_fixnoise(fixnoise, metacal_caching):
rng = np.random.RandomState(seed=100)

obs = _get_obs(rng, noise=0.005)
Expand All @@ -152,7 +152,7 @@ def test_metacal_fixnoise(fixnoise):
assert mobs.pixels[0]['ierr'] == np.sqrt(obs.weight[0, 0])


def test_metacal_fixnoise_noise_image():
def test_metacal_fixnoise_noise_image(metacal_caching):
rng = np.random.RandomState(seed=100)

obs = _get_obs(rng, noise=0.005, set_noise_image=True)
Expand All @@ -171,7 +171,7 @@ def test_metacal_fixnoise_noise_image():
assert mobs.pixels[0]['ierr'] == np.sqrt(obs.weight[0, 0]/2)


def test_metacal_errors():
def test_metacal_errors(metacal_caching):
rng = np.random.RandomState(seed=100)
obs = _get_obs(rng, noise=0.005, set_noise_image=True)

Expand Down Expand Up @@ -208,6 +208,6 @@ def _do_test_low_psf_s2n():
ngmix.metacal.get_all_metacal(obs=obs, rng=rng, psf='fitgauss')


def test_low_psf_s2n():
def test_low_psf_s2n(metacal_caching):
with pytest.raises(ngmix.BootPSFFailure):
_do_test_low_psf_s2n()
Loading