Skip to content

Commit

Permalink
Consolidate Cubeviz parser conversions (#3221)
Browse files Browse the repository at this point in the history
* Consolidate Cubeviz parser conversions
  • Loading branch information
pllim authored Oct 18, 2024
1 parent 19242dd commit 0ffca43
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 118 deletions.
156 changes: 43 additions & 113 deletions jdaviz/configs/cubeviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from astropy.nddata import StdDevUncertainty
from astropy.time import Time
from astropy.wcs import WCS
from specutils import Spectrum1D, SpectralAxis
from specutils import Spectrum1D

from jdaviz.core.custom_units import PIX2
from jdaviz.core.registries import data_parser_registry
from jdaviz.core.validunits import check_if_unit_is_per_solid_angle
from jdaviz.utils import standardize_metadata, PRIHDR_KEY, download_uri_to_path

from jdaviz.utils import (standardize_metadata, PRIHDR_KEY, download_uri_to_path,
_eqv_flux_to_sb_pixel)

__all__ = ['parse_data']

Expand Down Expand Up @@ -188,23 +188,22 @@ def _return_spectrum_with_correct_units(flux, wcs, metadata, data_type=None,
Also converts flux units to flux/pix2 solid angle units, if `flux` is not a surface
brightness and `apply_pix2` is True.
"""
# handle scale factors when they are included in the unit
# (has to be done before Spectrum1D creation)
if not np.isclose(flux.unit.scale, 1, rtol=1e-5):
flux = flux.to(flux.unit / flux.unit.scale)

with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='Input WCS indicates that the spectral axis is not last',
category=UserWarning)
sc = Spectrum1D(flux=flux, wcs=wcs, meta=metadata, uncertainty=uncertainty, mask=mask)

# convert flux and uncertainty to per-pix2 if input is not a surface brightness
if apply_pix2:
if not check_if_unit_is_per_solid_angle(flux.unit):
flux = flux / PIX2
if uncertainty is not None:
uncertainty = uncertainty / PIX2

# handle scale factors when they are included in the unit
if not np.isclose(flux.unit.scale, 1.0, rtol=1e-5):
flux = flux.to(flux.unit / flux.unit.scale)

sc = Spectrum1D(flux=flux, wcs=wcs, uncertainty=uncertainty, mask=mask)
# convert flux and uncertainty to per-pix2 if input is not a surface brightness
target_flux_unit = None
if (apply_pix2 and (data_type != "mask") and
(not check_if_unit_is_per_solid_angle(flux.unit))):
target_flux_unit = flux.unit / PIX2

if target_wave_unit is None and hdulist is not None:
found_target = False
Expand All @@ -223,23 +222,22 @@ def _return_spectrum_with_correct_units(flux, wcs, metadata, data_type=None,
found_target = True
break

if (data_type == 'flux' and target_wave_unit is not None
and target_wave_unit != sc.spectral_axis.unit):
metadata['_orig_spec'] = sc
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', message='Input WCS indicates that the spectral axis is not last',
category=UserWarning)
new_sc = Spectrum1D(
flux=sc.flux,
spectral_axis=sc.spectral_axis.to(target_wave_unit, u.spectral()),
meta=metadata,
uncertainty=sc.uncertainty,
mask=sc.mask
)
else:
sc.meta = metadata
if target_wave_unit == sc.spectral_axis.unit:
target_wave_unit = None

if (target_wave_unit is None) and (target_flux_unit is None): # Nothing to convert
new_sc = sc
elif target_flux_unit is None: # Convert wavelength only
new_sc = sc.with_spectral_axis_unit(target_wave_unit)
elif target_wave_unit is None: # Convert flux only and only PIX2 stuff
new_sc = sc.with_flux_unit(target_flux_unit, equivalencies=_eqv_flux_to_sb_pixel())
else: # Convert both
new_sc = sc.with_spectral_axis_and_flux_units(
target_wave_unit, target_flux_unit, flux_equivalencies=_eqv_flux_to_sb_pixel())

if target_wave_unit is not None:
new_sc.meta['_orig_spec'] = sc # Need this for later

return new_sc


Expand Down Expand Up @@ -300,7 +298,7 @@ def _parse_hdulist(app, hdulist, file_name=None,
metadata['_orig_spatial_wcs'] = _get_celestial_wcs(wcs)

apply_pix2 = data_type in ['flux', 'uncert']
sc = _return_spectrum_with_correct_units(flux, wcs, metadata, data_type,
sc = _return_spectrum_with_correct_units(flux, wcs, metadata, data_type=data_type,
hdulist=hdulist, apply_pix2=apply_pix2)

app.add_data(sc, data_label)
Expand Down Expand Up @@ -358,7 +356,8 @@ def _parse_jwst_s3d(app, hdulist, data_label, ext='SCI',
if hdu.name != 'PRIMARY' and 'PRIMARY' in hdulist:
metadata[PRIHDR_KEY] = standardize_metadata(hdulist['PRIMARY'].header)

data = _return_spectrum_with_correct_units(flux, wcs, metadata, data_type, hdulist=hdulist)
data = _return_spectrum_with_correct_units(
flux, wcs, metadata, data_type=data_type, hdulist=hdulist)
app.add_data(data, data_label, parent=parent)

# get glue data and update if DQ:
Expand Down Expand Up @@ -418,7 +417,8 @@ def _parse_esa_s3d(app, hdulist, data_label, ext='DATA', flux_viewer_reference_n
# to sky regions, where the parent data of the subset might have dropped spatial WCS info
metadata['_orig_spatial_wcs'] = _get_celestial_wcs(wcs)

data = _return_spectrum_with_correct_units(flux, wcs, metadata, data_type, hdulist=hdulist)
data = _return_spectrum_with_correct_units(
flux, wcs, metadata, data_type=data_type, hdulist=hdulist)

app.add_data(data, data_label)

Expand Down Expand Up @@ -466,12 +466,10 @@ def _parse_spectrum1d_3d(app, file_obj, data_label=None,
if hasattr(file_obj, 'wcs'):
meta['_orig_spatial_wcs'] = _get_celestial_wcs(file_obj.wcs)

s1d = _return_spectrum_with_correct_units(flux, wcs=file_obj.wcs, metadata=meta)

# convert data loaded in flux units to a per-square-pixel surface
# Also convert data loaded in flux units to a per-square-pixel surface
# brightness unit (e.g Jy to Jy/pix**2)
if (attr != "mask") and (not check_if_unit_is_per_solid_angle(flux.unit)):
s1d = convert_spectrum1d_from_flux_to_flux_per_pixel(s1d)
s1d = _return_spectrum_with_correct_units(
flux, file_obj.wcs, meta, data_type=attr, apply_pix2=True)

cur_data_label = app.return_data_label(data_label, attr.upper())
app.add_data(s1d, cur_data_label)
Expand Down Expand Up @@ -502,7 +500,8 @@ def _parse_spectrum1d(app, file_obj, data_label=None, spectrum_viewer_reference_
# convert data loaded in flux units to a per-square-pixel surface
# brightness unit (e.g Jy to Jy/pix**2)
if not check_if_unit_is_per_solid_angle(file_obj.flux.unit):
file_obj = convert_spectrum1d_from_flux_to_flux_per_pixel(file_obj)
file_obj = file_obj.with_flux_unit(
file_obj.flux.unit / PIX2, equivalencies=_eqv_flux_to_sb_pixel())

app.add_data(file_obj, data_label)
app.add_data_to_viewer(spectrum_viewer_reference_name, data_label)
Expand All @@ -522,15 +521,15 @@ def _parse_ndarray(app, file_obj, data_label=None, data_type=None,
flux = file_obj

if not hasattr(flux, 'unit'):
flux = flux << u.count
flux = flux << (u.count / PIX2)

meta = standardize_metadata({'_orig_spatial_wcs': None})
s3d = Spectrum1D(flux=flux, meta=meta)

# convert data loaded in flux units to a per-square-pixel surface
# brightness unit (e.g Jy to Jy/pix**2)
if not check_if_unit_is_per_solid_angle(s3d.unit):
file_obj = convert_spectrum1d_from_flux_to_flux_per_pixel(s3d)
if not check_if_unit_is_per_solid_angle(s3d.flux.unit):
s3d = s3d.with_flux_unit(s3d.flux.unit / PIX2, equivalencies=_eqv_flux_to_sb_pixel())

app.add_data(s3d, data_label)

Expand All @@ -556,12 +555,7 @@ def _parse_gif(app, file_obj, data_label=None, flux_viewer_reference_name=None):
flux = np.rot90(np.moveaxis(flux, 0, 2), k=-1, axes=(0, 1))

meta = {'filename': file_name, '_orig_spatial_wcs': None}
s3d = Spectrum1D(flux=flux * u.count, meta=standardize_metadata(meta))

# convert data loaded in flux units to a per-square-pixel surface
# brightness unit (e.g Jy to Jy/pix**2)
if not check_if_unit_is_per_solid_angle(s3d):
file_obj = convert_spectrum1d_from_flux_to_flux_per_pixel(s3d)
s3d = Spectrum1D(flux=flux * (u.count / PIX2), meta=standardize_metadata(meta))

app.add_data(s3d, data_label)
app.add_data_to_viewer(flux_viewer_reference_name, data_label)
Expand All @@ -580,67 +574,3 @@ def _get_data_type_by_hdu(hdu):
else:
data_type = ''
return data_type


def convert_spectrum1d_from_flux_to_flux_per_pixel(spectrum):
"""
Converts a Spectrum1D object's flux units to flux per square pixel.
This function takes a `specutils.Spectrum1D` object with flux units and converts the
flux (and optionally, uncertainty) to a surface brightness per square pixel
(e.g., from Jy to Jy/pix**2). This is done by updating the units of spectrum.flux
and (if present) spectrum.uncertainty, and creating a new `specutils.Spectrum1D`
object with the modified flux and uncertainty.
Parameters
----------
spectrum : Spectrum1D
A `specutils.Spectrum1D` object containing flux data, which is assumed to be in
flux units without any angular component in the denominator.
Returns
-------
Spectrum1D
A new `specutils.Spectrum1D` object with flux and uncertainty (if present)
converted to units of flux per square pixel.
"""

# convert flux, which is always populated
flux = getattr(spectrum, 'flux')
flux = flux / PIX2

# and uncerts, if present
uncerts = getattr(spectrum, 'uncertainty')
if uncerts is not None:
# enforce common uncert type.
uncerts = uncerts.represent_as(StdDevUncertainty)
uncerts = StdDevUncertainty(uncerts.quantity / PIX2)

# create a new spectrum 1d with all the info from the input spectrum 1d,
# and the flux / uncerts converted from flux to SB per square pixel

# if there is a spectral axis that is a SpectralAxis, you cant also set
# redshift or radial_velocity
spectral_axis = getattr(spectrum, 'spectral_axis', None)
if spectral_axis is not None:
if isinstance(spectral_axis, SpectralAxis):
redshift = None
radial_velocity = None
else:
redshift = spectrum.redshift
radial_velocity = spectrum.radial_velocity

# initialize new spectrum1d with new flux, uncerts, and all other init parameters
# from old input spectrum as well as any 'meta'. any more missing information
# not in init signature that might be present in `spectrum`?
new_spec1d = Spectrum1D(flux=flux, uncertainty=uncerts,
spectral_axis=spectrum.spectral_axis,
mask=spectrum.mask,
wcs=spectrum.wcs,
velocity_convention=spectrum.velocity_convention,
rest_value=spectrum.rest_value, redshift=redshift,
radial_velocity=radial_velocity,
bin_specification=getattr(spectrum, 'bin_specification', None),
meta=spectrum.meta)

return new_spec1d
Original file line number Diff line number Diff line change
Expand Up @@ -551,7 +551,7 @@ def _return_extracted(self, cube, wcs, collapsed_nddata):
uncertainty = collapsed_nddata.uncertainty

collapsed_spec = _return_spectrum_with_correct_units(
flux, wcs, collapsed_nddata.meta, 'flux',
flux, wcs, collapsed_nddata.meta, data_type='flux',
target_wave_unit=target_wave_unit,
uncertainty=uncertainty,
mask=mask
Expand Down
4 changes: 2 additions & 2 deletions jdaviz/configs/cubeviz/plugins/tests/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,15 +194,15 @@ def test_numpy_cube(cubeviz_helper):
assert data.label == 'Array'
assert data.shape == (4, 3, 2) # x, y, z
assert isinstance(data.coords, PaddedSpectrumWCS)
assert flux.units == 'ct'
assert flux.units == 'ct / pix2'

# Check context of second cube.
data = cubeviz_helper.app.data_collection[1]
flux = data.get_component('flux')
assert data.label == 'uncert_array'
assert data.shape == (4, 3, 2) # x, y, z
assert isinstance(data.coords, PaddedSpectrumWCS)
assert flux.units == 'ct'
assert flux.units == 'ct / pix2'


def test_invalid_data_types(cubeviz_helper):
Expand Down
6 changes: 4 additions & 2 deletions jdaviz/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,9 +549,11 @@ def _eqv_flux_to_sb_pixel():

# generate an equivalency for each flux type that would need
# another equivalency for converting to/from
flux_units = [u.MJy, u.erg / (u.s * u.cm**2 * u.Angstrom),
flux_units = [u.MJy,
u.erg / (u.s * u.cm**2 * u.Angstrom),
u.ph / (u.Angstrom * u.s * u.cm**2),
u.ph / (u.Hz * u.s * u.cm**2)]
u.ph / (u.Hz * u.s * u.cm**2),
u.ct]
return [(flux_unit, flux_unit / PIX2, lambda x: x, lambda x: x)
for flux_unit in flux_units]

Expand Down

0 comments on commit 0ffca43

Please sign in to comment.