Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AL-837: Calculate output WCS for resample from already-known s_region #307

Merged
merged 8 commits into from
Oct 29, 2024
1 change: 1 addition & 0 deletions changes/307.apichange.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make wcs_from_footprints accept s_regions instead of wcs objects
186 changes: 68 additions & 118 deletions src/stcal/alignment/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import functools
import logging
import re
from typing import TYPE_CHECKING

if TYPE_CHECKING:
Expand All @@ -15,7 +16,6 @@
from astropy import wcs as fitswcs
from astropy.coordinates import SkyCoord
from astropy.modeling import models as astmodels
from astropy.utils.misc import isiterable
from gwcs.wcstools import wcs_from_fiducial

log = logging.getLogger(__name__)
Expand All @@ -28,6 +28,7 @@
"compute_s_region_imaging",
"compute_s_region_keyword",
"wcs_from_footprints",
"wcs_bbox_from_shape",
"reproject",
]

Expand All @@ -41,15 +42,15 @@
Parameters
----------
spatial_footprint : np.ndarray
A 2xN array containing the world coordinates of the WCS footprint's
An Nx2 array containing the world coordinates of the WCS footprint's
bounding box, where N is the number of bounding box positions.

Returns
-------
lon_fiducial, lat_fiducial : np.ndarray, np.ndarray
The world coordinates of the fiducial point in the output coordinate frame.
"""
lon, lat = spatial_footprint
lon, lat = spatial_footprint.T
lon, lat = np.deg2rad(lon), np.deg2rad(lat)
x = np.cos(lat) * np.cos(lon)
y = np.cos(lat) * np.sin(lon)
Expand Down Expand Up @@ -139,15 +140,16 @@
return transform


def _get_axis_min_and_bounding_box(wcs_list: list[gwcs.wcs.WCS],
def _get_axis_min_and_bounding_box(footprints: list[np.ndarray],
ref_wcs: gwcs.wcs.WCS) -> tuple:
"""
Calculates axis minimum values and bounding box.

Parameters
----------
wcs_list : list
The list of WCS objects.
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.

ref_wcs : ~gwcs.wcs.WCS
The reference WCS object.
Expand All @@ -160,8 +162,7 @@
2 - a tuple containing the bounding box region in the format
((x0_lower, x0_upper), (x1_lower, x1_upper)).
"""
footprints = [w.footprint().T for w in wcs_list]
domain_bounds = np.hstack([ref_wcs.backward_transform(*f) for f in footprints])
domain_bounds = np.hstack([ref_wcs.backward_transform(*f.T) for f in footprints])
axis_min_values = np.min(domain_bounds, axis=1)
domain_bounds = (domain_bounds.T - axis_min_values).T

Expand All @@ -177,24 +178,17 @@
return (axis_min_values, output_bounding_box)


def _calculate_fiducial(wcs_list: list[gwcs.wcs.WCS],
bounding_box: Sequence | None,
crval: Sequence | None = None) -> np.ndarray:
def _calculate_fiducial(footprints: list[np.ndarray],
crval: Sequence | None = None) -> tuple:
"""
Calculates the coordinates of the fiducial point and, if necessary, updates it with
the values in CRVAL (the update is applied to spatial axes only).

Parameters
----------
wcs_list : list
A list of WCS objects.

bounding_box : tuple, or list
The bounding box over which the WCS is valid. It can be a either tuple of tuples
or a list of lists of size 2 where each element represents a range of
(low, high) values. The bounding_box is in the order of the axes, axes_order.
For two inputs and axes_order(0, 1) the bounding box can be either
((xlow, xhigh), (ylow, yhigh)) or [[xlow, xhigh], [ylow, yhigh]].
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.

crval : list, optional
A reference world coordinate associated with the reference pixel. If not `None`,
Expand All @@ -203,21 +197,15 @@

Returns
-------
fiducial : np.ndarray
A two-elements array containing the world coordinate of the fiducial point.
fiducial : tuple
A tuple containing the world coordinate of the fiducial point.
"""
fiducial = compute_fiducial(wcs_list, bounding_box=bounding_box)
if crval is not None:
i = 0
for k, axt in enumerate(wcs_list[0].output_frame.axes_type):
if axt == "SPATIAL":
# overwrite only spatial axes with user-provided CRVAL
fiducial[k] = crval[i]
i += 1
return fiducial
return tuple(crval)

Check warning on line 204 in src/stcal/alignment/util.py

View check run for this annotation

Codecov / codecov/patch

src/stcal/alignment/util.py#L204

Added line #L204 was not covered by tests
return compute_fiducial(footprints)


def _calculate_offsets(fiducial: np.ndarray,
def _calculate_offsets(fiducial: tuple,
wcs: gwcs.wcs.WCS | None,
axis_min_values: np.ndarray | None,
crpix: Sequence | None) -> astmodels.Model:
Expand All @@ -226,8 +214,8 @@

Parameters
----------
fiducial : np.ndarray
A two-elements containing the world coordinates of the fiducial point.
fiducial : tuple
A tuple containing the world coordinates of the fiducial point.

wcs : ~gwcs.wcs.WCS
A WCS object. It will be used to determine the
Expand Down Expand Up @@ -265,13 +253,14 @@

def _calculate_new_wcs(wcs: gwcs.wcs.WCS,
shape: Sequence | None,
wcs_list: list[gwcs.wcs.WCS],
fiducial: np.ndarray,
footprints: list[np.ndarray],
fiducial: tuple,
crpix: Sequence | None = None,
transform: astmodels.Model | None = None,
) -> gwcs.wcs.WCS:
"""
Calculates a new WCS object based on the combined WCS objects provided.
Calculates a new WCS object based on the combined footprints
and reference WCS provided.

Parameters
----------
Expand All @@ -282,11 +271,12 @@
The shape of the new WCS's pixel grid. If `None`, then the output bounding box
will be used to determine it.

wcs_list : list
A list containing WCS objects.
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.

fiducial : np.ndarray
A two-elements array containing the location on the sky in some standard
fiducial : tuple
A tuple containing the location on the sky in some standard
coordinate system.

crpix : tuple, optional
Expand All @@ -309,7 +299,7 @@
transform=transform,
input_frame=wcs.input_frame,
)
axis_min_values, output_bounding_box = _get_axis_min_and_bounding_box(wcs_list, wcs_new)
axis_min_values, output_bounding_box = _get_axis_min_and_bounding_box(footprints, wcs_new)
offsets = _calculate_offsets(
fiducial=fiducial,
wcs=wcs_new,
Expand All @@ -328,43 +318,6 @@
return wcs_new


def _validate_wcs_list(wcs_list: list[gwcs.wcs.WCS]) -> bool:
"""
Validates wcs_list.

Parameters
----------
wcs_list : list
A list of WCS objects.

Returns
-------
bool or Exception
If wcs_list is valid, returns True. Otherwise, it will raise an error.

Raises
------
ValueError
Raised whenever wcs_list is not an iterable.
TypeError
Raised whenever wcs_list is empty or any of its content is not an
instance of WCS.
"""
if not isiterable(wcs_list):
msg = "Expected 'wcs_list' to be an iterable of WCS objects."
raise ValueError(msg)

if len(wcs_list):
if not all(isinstance(w, gwcs.WCS) for w in wcs_list):
msg = "All items in 'wcs_list' are to be instances of gwcs.wcs.WCS."
raise TypeError(msg)
else:
msg = "'wcs_list' should not be empty."
raise TypeError(msg)

return True


def compute_scale(
wcs: gwcs.wcs.WCS,
fiducial: tuple | np.ndarray,
Expand Down Expand Up @@ -430,18 +383,17 @@
return float(np.sqrt(xscale * yscale))


def compute_fiducial(wcslist: list,
bounding_box: Sequence | None = None) -> np.ndarray:
def compute_fiducial(footprints: list[np.ndarray]) -> tuple:
"""
Calculates the world coordinates of the fiducial point of a list of WCS objects.
For a celestial footprint this is the center. For a spectral footprint, it is the
beginning of its range.

Parameters
----------
wcslist : list
A list containing all the WCS objects for which the fiducial is to be
calculated.
footprints : list
A list of numpy arrays each of shape (N, 2) containing the
(RA, Dec) vertices demarcating the footprint of the input WCSs.

bounding_box : tuple, list, None
The bounding box over which the WCS is valid. It can be a either tuple of tuples
Expand All @@ -452,27 +404,16 @@

Returns
-------
fiducial : np.ndarray
A two-elements array containing the world coordinates of the fiducial point
fiducial : tuple
A tuple containing the world coordinates of the fiducial point
in the combined output coordinate frame.

Notes
-----
This function assumes all WCSs have the same output coordinate frame.
"""
axes_types = wcslist[0].output_frame.axes_type
spatial_axes = np.array(axes_types) == "SPATIAL"
spectral_axes = np.array(axes_types) == "SPECTRAL"
footprints = np.hstack([w.footprint(bounding_box=bounding_box).T for w in wcslist])
spatial_footprint = footprints[spatial_axes]
spectral_footprint = footprints[spectral_axes]

fiducial = np.empty(len(axes_types))
if spatial_footprint.any():
fiducial[spatial_axes] = _calculate_fiducial_from_spatial_footprint(spatial_footprint)
if spectral_footprint.any():
fiducial[spectral_axes] = spectral_footprint.min()
return fiducial
spatial_footprint = np.vstack(footprints)
return _calculate_fiducial_from_spatial_footprint(spatial_footprint)


def calc_rotation_matrix(roll_ref: float, v3i_yangle: float, vparity: int = 1) -> list[float]:
Expand Down Expand Up @@ -519,12 +460,27 @@
return [pc1_1, pc1_2, pc2_1, pc2_2]


def _sregion_to_footprint(s_region: str) -> np.ndarray:
"""
Parameters
----------
s_region : str
The S_REGION header keyword

Returns
-------
footprint : np.array
A 2D array of the footprint of the region, shape (N, 2)
"""
no_prefix = re.sub(r"[a-zA-Z]", "", s_region)
return np.array(no_prefix.split(), dtype=float).reshape(-1, 2)


def wcs_from_footprints(
wcs_list: list[gwcs.wcs.WCS],
footprints: list[np.ndarray] | list[str],
ref_wcs: gwcs.wcs.WCS,
ref_wcsinfo: dict,
transform: astropy.modeling.models.Model | None = None,
bounding_box: Sequence | None = None,
pscale_ratio: float | None = None,
pscale: float | None = None,
rotation: float | None = None,
Expand All @@ -549,13 +505,15 @@

Parameters
----------
wcs_list : list
A list of valid datamodels.
footprints : list of np.ndarray or list of str
If list elements are numpy arrays, each should have shape (N, 2) and contain
(RA, Dec) vertices demarcating the footprint of the input WCSs.
If list elements are strings, each should be the S_REGION header keyword
containing (RA, Dec) vertices demarcating the footprint of the input WCSs.

ref_wcs :
A valid datamodel whose WCS is used as reference for the creation of the output
A WCS used as reference for the creation of the output
coordinate frame, projection, and scaling and rotation transforms.
If not supplied the first model in the list is used as ``refmodel``.

ref_wcsinfo : dict
A dictionary containing the WCS FITS keywords and corresponding values.
Expand All @@ -564,10 +522,6 @@
A transform, passed to :py:func:`gwcs.wcstools.wcs_from_fiducial`
If not supplied `Scaling | Rotation` is computed from ``refmodel``.

bounding_box : tuple
Bounding_box of the new WCS.
If not supplied it is computed from the bounding_box of all inputs.

pscale_ratio : float, None
Ratio of input to output pixel scale. Ignored when either
``transform`` or ``pscale`` are provided.
Expand Down Expand Up @@ -600,25 +554,21 @@
Right ascension and declination of the reference pixel. Automatically
computed if not provided.

wcs_list : list
A list of WCS objects. If not supplied, the WCS objects are extracted
from the input datamodels.

Returns
-------
wcs_new : ~gwcs.wcs.WCS
The WCS object corresponding to the combined input footprints.

"""
_validate_wcs_list(wcs_list)

fiducial = _calculate_fiducial(wcs_list=wcs_list, bounding_box=bounding_box, crval=crval)

ref_wcs = wcs_list[0] if ref_wcs is None else ref_wcs
footprints = [_sregion_to_footprint(s_region)
if isinstance(s_region, str) else s_region
for s_region in footprints]
fiducial = _calculate_fiducial(footprints, crval=crval)

transform = _generate_tranform(
ref_wcs,
wcsinfo=ref_wcsinfo,
ref_wcsinfo,
pscale_ratio=pscale_ratio,
pscale=pscale,
rotation=rotation,
Expand All @@ -630,7 +580,7 @@
wcs=ref_wcs,
shape=shape,
crpix=crpix,
wcs_list=wcs_list,
footprints=footprints,
fiducial=fiducial,
transform=transform,
)
Expand Down
2 changes: 1 addition & 1 deletion src/stcal/tweakreg/astrometric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def create_astrometric_catalog(

def compute_radius(wcs):
"""Compute the radius from the center to the furthest edge of the WCS."""
fiducial = compute_fiducial([wcs], wcs.bounding_box)
fiducial = compute_fiducial([wcs.footprint(),])
img_center = SkyCoord(
ra=fiducial[0] * u.degree,
dec=fiducial[1] * u.degree)
Expand Down
Loading
Loading