Skip to content

Commit

Permalink
Merge pull request #798 from Cadair/maybe_this_will_work
Browse files Browse the repository at this point in the history
Reduce memory consuption in axis_world_coords (again)
  • Loading branch information
nabobalis authored Jan 13, 2025
2 parents 067ddf8 + 1a5bd72 commit 1c9faa1
Show file tree
Hide file tree
Showing 7 changed files with 137 additions and 84 deletions.
1 change: 0 additions & 1 deletion changelog/780.bugfix.rst

This file was deleted.

1 change: 1 addition & 0 deletions changelog/798.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added an internal code to shortcut non-correlated axes avoiding the creation of a full coordinate grid, reducing memory use in specific circumstances.
105 changes: 105 additions & 0 deletions ndcube/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,86 @@ def wcs_3d_ln_lt_t_rotated():
return WCS(header=h_rotated)


@pytest.fixture
def wcs_3d_ln_lt_l_coupled():
# WCS for a 3D data cube with two celestial axes and one wavelength axis.
# The latitudinal dimension is coupled to the third pixel dimension through
# a single off diagonal element in the PCij matrix
header = {
'CTYPE1': 'HPLN-TAN',
'CRPIX1': 5,
'CDELT1': 5,
'CUNIT1': 'arcsec',
'CRVAL1': 0.0,

'CTYPE2': 'HPLT-TAN',
'CRPIX2': 5,
'CDELT2': 5,
'CUNIT2': 'arcsec',
'CRVAL2': 0.0,

'CTYPE3': 'WAVE',
'CRPIX3': 1.0,
'CDELT3': 1,
'CUNIT3': 'Angstrom',
'CRVAL3': 1.0,

'PC1_1': 1,
'PC1_2': 0,
'PC1_3': 0,
'PC2_1': 0,
'PC2_2': 1,
'PC2_3': -1.0,
'PC3_1': 0.0,
'PC3_2': 0.0,
'PC3_3': 1.0,

'WCSAXES': 3,

'DATEREF': "2020-01-01T00:00:00"
}
return WCS(header=header)


@pytest.fixture
def wcs_3d_ln_lt_t_coupled():
# WCS for a 3D data cube with two celestial axes and one time axis.
header = {
'CTYPE1': 'HPLN-TAN',
'CRPIX1': 5,
'CDELT1': 5,
'CUNIT1': 'arcsec',
'CRVAL1': 0.0,

'CTYPE2': 'HPLT-TAN',
'CRPIX2': 5,
'CDELT2': 5,
'CUNIT2': 'arcsec',
'CRVAL2': 0.0,

'CTYPE3': 'UTC',
'CRPIX3': 1.0,
'CDELT3': 1,
'CUNIT3': 's',
'CRVAL3': 1.0,

'PC1_1': 1,
'PC1_2': 0,
'PC1_3': 0,
'PC2_1': 0,
'PC2_2': 1,
'PC2_3': 0,
'PC3_1': 0,
'PC3_2': 1,
'PC3_3': 1,

'WCSAXES': 3,

'DATEREF': "2020-01-01T00:00:00"
}
return WCS(header=header)


################################################################################
# Extra and Global Coords Fixtures
################################################################################
Expand Down Expand Up @@ -519,6 +599,31 @@ def ndcube_3d_rotated(wcs_3d_ln_lt_t_rotated, simple_extra_coords_3d):
return cube


@pytest.fixture
def ndcube_3d_coupled(wcs_3d_ln_lt_l_coupled):
shape = (128, 256, 512)
wcs_3d_ln_lt_l_coupled.array_shape = shape
data = data_nd(shape)
mask = data > 0
return NDCube(
data,
wcs_3d_ln_lt_l_coupled,
mask=mask,
uncertainty=data,
)


@pytest.fixture
def ndcube_3d_coupled_time(wcs_3d_ln_lt_t_coupled):
shape = (128, 256, 512)
wcs_3d_ln_lt_t_coupled.array_shape = shape
data = data_nd(shape)
return NDCube(
data,
wcs_3d_ln_lt_t_coupled,
)


@pytest.fixture
def ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l):
return gen_ndcube_3d_l_ln_lt_ectime(wcs_3d_lt_ln_l,
Expand Down
87 changes: 9 additions & 78 deletions ndcube/ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from astropy.units import UnitsError
from astropy.wcs.utils import _split_matrix

from ndcube.utils.wcs import world_axis_to_pixel_axes

try:
# Import sunpy coordinates if available to register the frames and WCS functions with astropy
import sunpy.coordinates # NOQA
Expand Down Expand Up @@ -486,47 +484,9 @@ def quantity(self):
"""Unitful representation of the NDCube data."""
return u.Quantity(self.data, self.unit, copy=_NUMPY_COPY_IF_NEEDED)

def _generate_independent_world_coords(self, pixel_corners, wcs, needed_axes, units):
"""
Generate world coordinates for independent axes.
The idea is to workout only the specific grid that is needed for independent axes.
This speeds up the calculation of world coordinates and reduces memory usage.
Parameters
----------
pixel_corners : bool
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The required pixel axes.
units : bool
If units are needed.
Returns
-------
array-like
The world coordinates.
"""
needed_axes = np.array(needed_axes).squeeze()
if self.data.ndim in needed_axes:
required_axes = needed_axes - 1
else:
required_axes = needed_axes
lims = (-0.5, self.data.shape[::-1][required_axes] + 1) if pixel_corners else (0, self.data.shape[::-1][required_axes])
indices = [np.arange(lims[0], lims[1]) if wanted else [0] for wanted in wcs.axis_correlation_matrix[required_axes]]
world_coords = wcs.pixel_to_world_values(*indices)
if units:
world_coords = world_coords << u.Unit(wcs.world_axis_units[needed_axes])
return world_coords

def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, units):
def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None):
"""
Generate world coordinates for dependent axes.
This will work out the exact grid that is needed for dependent axes
and can be time and memory consuming.
Private method to generate world coordinates.
Parameters
----------
Expand All @@ -535,7 +495,7 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The required pixel axes.
The axes that are needed.
units : bool
If units are needed.
Expand Down Expand Up @@ -573,6 +533,12 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit
# And inject 0s for those coordinates
for idx in non_corr_axes:
sub_range.insert(idx, 0)
# If we are subsetting world axes, ignore any pixel axes which are not correlated with our requested world axis.
if any(world_axis in needed_axes for world_axis in world_axes_indices):
needed_pixel_axes = wcs.axis_correlation_matrix[needed_axes]
unneeded_pixel_axes = np.argwhere(needed_pixel_axes.sum(axis=0) == 0)[:, 0]
for idx in unneeded_pixel_axes:
sub_range[idx] = 0
# Generate a grid of broadcastable pixel indices for all pixel dimensions
grid = np.meshgrid(*sub_range, indexing='ij')
# Convert to world coordinates
Expand All @@ -592,41 +558,6 @@ def _generate_dependent_world_coords(self, pixel_corners, wcs, needed_axes, unit
world_coords[i] = coord << u.Unit(unit)
return world_coords

def _generate_world_coords(self, pixel_corners, wcs, *, needed_axes, units=None):
"""
Private method to generate world coordinates.
Handles both dependent and independent axes.
Parameters
----------
pixel_corners : bool
If one needs pixel corners, otherwise pixel centers.
wcs : astropy.wcs.WCS
The WCS.
needed_axes : array-like
The axes that are needed.
units : bool
If units are needed.
Returns
-------
array-like
The world coordinates.
"""
axes_are_independent = []
pixel_axes = set()
for world_axis in needed_axes:
pix_ax = world_axis_to_pixel_axes(world_axis, wcs.axis_correlation_matrix)
axes_are_independent.append(len(pix_ax) == 1)
pixel_axes = pixel_axes.union(set(pix_ax))
pixel_axes = list(pixel_axes)
if all(axes_are_independent) and len(pixel_axes) == len(needed_axes) and len(needed_axes) != 0:
world_coords = self._generate_independent_world_coords(pixel_corners, wcs, needed_axes, units)
else:
world_coords = self._generate_dependent_world_coords(pixel_corners, wcs, needed_axes, units)
return world_coords

@utils.cube.sanitize_wcs
def axis_world_coords(self, *axes, pixel_corners=False, wcs=None):
# Docstring in NDCubeABC.
Expand Down
18 changes: 16 additions & 2 deletions ndcube/tests/test_ndcube.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,20 @@ def test_axis_world_coords_wave_ec(ndcube_3d_l_ln_lt_ectime):
assert coords[0].shape == (5,)


@pytest.mark.limit_memory("12 MB")
def test_axis_world_coords_wave_coupled_dims(ndcube_3d_coupled):
cube = ndcube_3d_coupled

cube.axis_world_coords('em.wl')


@pytest.mark.limit_memory("12 MB")
def test_axis_world_coords_time_coupled_dims(ndcube_3d_coupled_time):
cube = ndcube_3d_coupled_time

cube.axis_world_coords('time')


def test_axis_world_coords_empty_ec(ndcube_3d_l_ln_lt_ectime):
cube = ndcube_3d_l_ln_lt_ectime
sub_cube = cube[:, 0]
Expand Down Expand Up @@ -292,10 +306,10 @@ def test_axis_world_coords_single_pixel_corners(axes, ndcube_3d_ln_lt_l):
assert u.allclose(coords[0], [1.02e-09, 1.04e-09, 1.06e-09, 1.08e-09] * u.m)

coords = ndcube_3d_ln_lt_l.axis_world_coords_values(*axes, pixel_corners=True)
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)

coords = ndcube_3d_ln_lt_l.axis_world_coords(*axes, pixel_corners=True)
assert u.allclose(coords[0], [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09, 1.11e-09] * u.m)
assert u.allclose(coords, [1.01e-09, 1.03e-09, 1.05e-09, 1.07e-09, 1.09e-09] * u.m)


@pytest.mark.parametrize(("ndc", "item"),
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ tests = [
"pytest-mpl>=0.12",
"pytest-xdist",
"pytest",
"pytest-memray; sys_platform != 'win32'",
"scipy",
"specutils",
"sunpy>=5.0.0",
Expand Down
8 changes: 5 additions & 3 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ addopts =
--doctest-continue-on-failure
mpl-results-path = figure_test_images
mpl-use-full-test-name = true
remote_data_strict = True
doctest_subpackage_requires =
docs/explaining_ndcube/* = numpy>=2.0.0
markers =
limit_memory: pytest-memray marker to fail a test if too much memory used
filterwarnings =
# Turn all warnings into errors so they do not pass silently.
error
Expand All @@ -53,6 +58,3 @@ filterwarnings =
ignore:FigureCanvasAgg is non-interactive, and thus cannot be shown:UserWarning
# Oldestdeps from gWCS
ignore:pkg_resources is deprecated as an API:DeprecationWarning
remote_data_strict = True
doctest_subpackage_requires =
docs/explaining_ndcube/* = numpy>=2.0.0

0 comments on commit 1c9faa1

Please sign in to comment.